mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
15 Commits
serve-quan
...
v4.50.3
Author | SHA1 | Date | |
---|---|---|---|
a78e884f31 | |||
e9a5e32b76 | |||
556b96d2e7 | |||
f7ba365881 | |||
b258dc35d5 | |||
f4cfe5df33 | |||
cfef91d802 | |||
6311953dd4 | |||
897130524b | |||
d9ccb9adbb | |||
e6ab93e702 | |||
650f607840 | |||
9abbb92297 | |||
0b057e66b5 | |||
26fbd6919a |
@ -105,59 +105,75 @@ inputs = processor.apply_chat_template(
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||
print(processor.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
### Multi-image Inference
|
||||
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||
|
||||
```python
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
||||
|
||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user", "content": [
|
||||
{"type": "image", "url": url_cow},
|
||||
{"type": "image", "url": url_stop},
|
||||
{"type": "text", "text": "Are these two images identical?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
).to(model.device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=50)
|
||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
||||
```py
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
|
||||
visualizer("<img>What is shown in this image?")
|
||||
```
|
||||
|
||||
### Text-only inference
|
||||
## Notes
|
||||
|
||||
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
|
||||
```python
|
||||
from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
|
||||
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
|
||||
|
||||
model_id = "google/gemma-3-1b-it"
|
||||
```py
|
||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||
url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
|
||||
messages =[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{"type": "text", "text": "You are a helpful assistant."}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": url_cow},
|
||||
{"type": "image", "url": url_cat},
|
||||
{"type": "text", "text": "Which image is cuter?"},
|
||||
]
|
||||
},
|
||||
]
|
||||
```
|
||||
- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
|
||||
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
|
||||
- By default, images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
|
||||
|
||||
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
|
||||
```diff
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
add_generation_prompt=True,
|
||||
+ do_pan_and_scan=True,
|
||||
).to("cuda")
|
||||
```
|
||||
- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/gemma-3-1b-pt",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-3-1b-pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
outputs = model.generate(**input_ids, max_new_tokens=100)
|
||||
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.50.0.dev0")
|
||||
check_min_version("4.50.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -446,7 +446,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.50.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.50.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.50.0.dev0"
|
||||
__version__ = "4.50.3"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -585,7 +585,7 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
|
||||
return torch.utils._pytree.tree_flatten(dictionary)[0]
|
||||
|
||||
|
||||
if is_torch_greater_or_equal("2.2"):
|
||||
if is_torch_greater_or_equal("2.3"):
|
||||
torch.utils._pytree.register_pytree_node(
|
||||
DynamicCache,
|
||||
_flatten_dynamic_cache,
|
||||
@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache):
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
|
||||
if not (
|
||||
torch.cuda.is_available()
|
||||
or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
|
||||
):
|
||||
raise RuntimeError(
|
||||
"OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
|
||||
"OffloadedCache can only be used with a GPU"
|
||||
+ (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.original_device = []
|
||||
self.prefetch_stream = None
|
||||
self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
|
||||
self.prefetch_stream = (
|
||||
torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
|
||||
)
|
||||
self.beam_idx = None # used to delay beam search operations
|
||||
|
||||
def prefetch_layer(self, layer_idx: int):
|
||||
"Starts prefetching the next layer cache"
|
||||
if layer_idx < len(self):
|
||||
with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
|
||||
with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
|
||||
self.prefetch_stream
|
||||
):
|
||||
# Prefetch next layer tensors to GPU
|
||||
device = self.original_device[layer_idx]
|
||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
|
||||
@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache):
|
||||
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
|
||||
if layer_idx < len(self):
|
||||
# Evict the previous layer if necessary
|
||||
if is_torch_greater_or_equal("2.7"):
|
||||
if is_torch_greater_or_equal("2.7", accept_dev=True):
|
||||
torch.accelerator.current_stream().synchronize()
|
||||
else:
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
@ -1122,7 +1122,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
|
||||
If `decoder` is set to `True`, then only search for decoder config names.
|
||||
Args:
|
||||
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||
If set to `True`, then only search for decoder config names.
|
||||
"""
|
||||
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
encoder_possible_text_config_names = ("text_encoder",)
|
||||
@ -1144,8 +1146,10 @@ class PretrainedConfig(PushToHubMixin):
|
||||
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
||||
)
|
||||
elif len(valid_text_config_names) == 1:
|
||||
return getattr(self, valid_text_config_names[0])
|
||||
return self
|
||||
config_to_return = getattr(self, valid_text_config_names[0])
|
||||
else:
|
||||
config_to_return = self
|
||||
return config_to_return
|
||||
|
||||
|
||||
def get_configuration_file(configuration_files: List[str]) -> str:
|
||||
|
@ -3887,9 +3887,14 @@ class GenerationMixin:
|
||||
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
|
||||
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
|
||||
|
||||
# Crop the static-shaped tensors to the actual size
|
||||
sequences = sequences[:, :cur_len]
|
||||
beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
|
||||
# Crop the static-shaped tensors to the actual size.
|
||||
# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
|
||||
# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
|
||||
# previous decoding iteration)
|
||||
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
|
||||
output_length = decoder_prompt_len + max_generated_length
|
||||
sequences = sequences[:, :output_length]
|
||||
beam_indices = beam_indices[:, :max_generated_length]
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
|
@ -72,6 +72,8 @@ if is_vision_available():
|
||||
PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
|
||||
PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
|
||||
}
|
||||
else:
|
||||
pil_torch_interpolation_mapping = {}
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = AlbertConfig.from_json_file(albert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = AlbertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_albert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--albert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ALBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
|
@ -1,389 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ALIGN checkpoints from the original repository."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import align
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
from transformers import (
|
||||
AlignConfig,
|
||||
AlignModel,
|
||||
AlignProcessor,
|
||||
BertConfig,
|
||||
BertTokenizer,
|
||||
EfficientNetConfig,
|
||||
EfficientNetImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
image = tf.image.resize(image, (346, 346))
|
||||
image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
|
||||
return image
|
||||
|
||||
|
||||
def get_align_config():
|
||||
vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
|
||||
vision_config.image_size = 289
|
||||
vision_config.hidden_dim = 640
|
||||
vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
|
||||
vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
|
||||
vision_config.depthwise_padding = []
|
||||
|
||||
text_config = BertConfig()
|
||||
config = AlignConfig.from_text_vision_configs(
|
||||
text_config=text_config, vision_config=vision_config, projection_dim=640
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def get_processor():
|
||||
image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
rescale_factor=1 / 127.5,
|
||||
rescale_offset=True,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
tokenizer.model_max_length = 64
|
||||
processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def rename_keys(original_param_names):
|
||||
# EfficientNet image encoder
|
||||
block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
|
||||
block_names = list(set(block_names))
|
||||
block_names = sorted(block_names)
|
||||
num_blocks = len(block_names)
|
||||
block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
|
||||
|
||||
rename_keys = []
|
||||
rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
|
||||
rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
|
||||
rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
|
||||
rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
|
||||
rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
|
||||
|
||||
for b in block_names:
|
||||
hf_b = block_name_mapping[b]
|
||||
rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
|
||||
rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
|
||||
rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
|
||||
)
|
||||
|
||||
rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
|
||||
rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
|
||||
rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
|
||||
rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
|
||||
)
|
||||
rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
|
||||
rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
|
||||
)
|
||||
|
||||
key_mapping = {}
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = "vision_model." + item[1]
|
||||
|
||||
# BERT text encoder
|
||||
rename_keys = []
|
||||
old = "tf_bert_model/bert"
|
||||
new = "text_model"
|
||||
for i in range(12):
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.query.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.key.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.self.value.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
|
||||
f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
|
||||
f"{new}.encoder.layer.{i}.intermediate.dense.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
|
||||
)
|
||||
|
||||
rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
|
||||
)
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
|
||||
rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))
|
||||
|
||||
rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
|
||||
rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
|
||||
rename_keys.append(("dense/kernel:0", "text_projection.weight"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("dense/bias:0", "text_projection.bias"))
|
||||
rename_keys.append(("temperature:0", "temperature"))
|
||||
|
||||
for item in rename_keys:
|
||||
if item[0] in original_param_names:
|
||||
key_mapping[item[0]] = item[1]
|
||||
return key_mapping
|
||||
|
||||
|
||||
def replace_params(hf_params, tf_params, key_mapping):
|
||||
list(hf_params.keys())
|
||||
|
||||
for key, value in tf_params.items():
|
||||
if key not in key_mapping:
|
||||
continue
|
||||
|
||||
hf_key = key_mapping[key]
|
||||
if "_conv" in key and "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
|
||||
elif "embeddings" in key:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
elif "depthwise_kernel" in key:
|
||||
new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
|
||||
elif "kernel" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value))
|
||||
elif "temperature" in key:
|
||||
new_hf_value = value
|
||||
elif "bn/gamma" or "bn/beta" in key:
|
||||
new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
|
||||
else:
|
||||
new_hf_value = torch.from_numpy(value)
|
||||
|
||||
# Replace HF parameters with original TF model parameters
|
||||
hf_params[hf_key].copy_(new_hf_value)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ALIGN structure.
|
||||
"""
|
||||
# Load original model
|
||||
seq_length = 64
|
||||
tok = Tokenizer(seq_length)
|
||||
original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
|
||||
original_model.compile()
|
||||
original_model.load_weights(checkpoint_path)
|
||||
|
||||
tf_params = original_model.trainable_variables
|
||||
tf_non_train_params = original_model.non_trainable_variables
|
||||
tf_params = {param.name: param.numpy() for param in tf_params}
|
||||
for param in tf_non_train_params:
|
||||
tf_params[param.name] = param.numpy()
|
||||
tf_param_names = list(tf_params.keys())
|
||||
|
||||
# Load HuggingFace model
|
||||
config = get_align_config()
|
||||
hf_model = AlignModel(config).eval()
|
||||
hf_params = hf_model.state_dict()
|
||||
|
||||
# Create src-to-dst parameter name mapping dictionary
|
||||
print("Converting parameters...")
|
||||
key_mapping = rename_keys(tf_param_names)
|
||||
replace_params(hf_params, tf_params, key_mapping)
|
||||
|
||||
# Initialize processor
|
||||
processor = get_processor()
|
||||
inputs = processor(
|
||||
images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
|
||||
)
|
||||
|
||||
# HF model inference
|
||||
hf_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
|
||||
hf_image_features = outputs.image_embeds.detach().numpy()
|
||||
hf_text_features = outputs.text_embeds.detach().numpy()
|
||||
|
||||
# Original model inference
|
||||
original_model.trainable = False
|
||||
tf_image_processor = EfficientNetImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_rescale=False,
|
||||
do_normalize=False,
|
||||
include_top=False,
|
||||
resample=Image.BILINEAR,
|
||||
)
|
||||
image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
|
||||
text = tok(tf.constant(["A picture of a cat"]))
|
||||
|
||||
image_features = original_model.image_encoder(image, training=False)
|
||||
text_features = original_model.text_encoder(text, training=False)
|
||||
|
||||
image_features = tf.nn.l2_normalize(image_features, axis=-1)
|
||||
text_features = tf.nn.l2_normalize(text_features, axis=-1)
|
||||
|
||||
# Check whether original and HF model outputs match -> np.allclose
|
||||
if not np.allclose(image_features, hf_image_features, atol=1e-3):
|
||||
raise ValueError("The predicted image features are not the same.")
|
||||
if not np.allclose(text_features, hf_text_features, atol=1e-3):
|
||||
raise ValueError("The predicted text features are not the same.")
|
||||
print("Model outputs match!")
|
||||
|
||||
if save_model:
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
# Save converted model and image processor
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
# Push model and image processor to hub
|
||||
print("Pushing converted ALIGN to the hub...")
|
||||
processor.push_to_hub("align-base")
|
||||
hf_model.push_to_hub("align-base")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="./weights/model-weights",
|
||||
type=str,
|
||||
help="Path to the pretrained TF ALIGN checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="hf_model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
|
@ -1,162 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AriaForConditionalGeneration,
|
||||
AriaProcessor,
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
|
||||
EPILOG_TXT = """Example:
|
||||
python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
|
||||
|
||||
Example for creating the old state dict file with Python:
|
||||
|
||||
import torch
|
||||
from aria.model.language_model.aria_llama import AriaTextForCausalLM
|
||||
|
||||
# load model
|
||||
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
|
||||
model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
|
||||
|
||||
# load vision tower
|
||||
model.get_vision_tower().load_model()
|
||||
|
||||
# Save state dict
|
||||
torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
|
||||
"""
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"vision_tower.vision_model": "vision_tower",
|
||||
"ln_ffn": "layer_norm",
|
||||
"ffn": "feed_forward",
|
||||
"ln_kv": "layer_norm_kv",
|
||||
}
|
||||
|
||||
|
||||
def load_original_state_dict(model_id):
|
||||
directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_to_hf(state_dict):
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key.endswith(".inv_freq"):
|
||||
continue
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
new_state_dict[key] = value
|
||||
new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
|
||||
new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_model_id,
|
||||
extra_special_tokens={
|
||||
"image_token": "<|img|>",
|
||||
"pad_token": "<pad>",
|
||||
},
|
||||
)
|
||||
tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
|
||||
processor = AriaProcessor.from_pretrained(
|
||||
text_model_id,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(text_model_id)
|
||||
config.vision_config.hidden_size = 1152
|
||||
config.vision_config.attention_heads = 16
|
||||
config.pad_token_id = 2
|
||||
config.image_token_index = 9
|
||||
config.intermediate_size = config.moe_intermediate_size
|
||||
config.auto_map = {
|
||||
"AutoConfig": "modeling_aria.AriaConfig",
|
||||
"AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
|
||||
}
|
||||
|
||||
with torch.device("meta"):
|
||||
model = AriaForConditionalGeneration(config)
|
||||
|
||||
state_dict = load_original_state_dict(old_state_dict_id)
|
||||
|
||||
state_dict = convert_state_dict_to_hf(state_dict)
|
||||
model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
|
||||
# print("Saving models")
|
||||
# model.save_pretrained("local_aria", safe_serialization=False)
|
||||
# processor.save_pretrained("local_aria")
|
||||
print("Pushing to hub")
|
||||
model.push_to_hub(output_hub_path, create_pr=True)
|
||||
processor.push_to_hub(output_hub_path, create_pr=True)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog=EPILOG_TXT,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the text model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vision_model_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Hub location of the vision model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_hub_path",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the converted model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old_state_dict_id",
|
||||
default="rhymes-ai/Aria",
|
||||
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,279 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_audio_spectrogram_transformer_config(model_name):
|
||||
config = ASTConfig()
|
||||
|
||||
if "10-10" in model_name:
|
||||
pass
|
||||
elif "speech-commands" in model_name:
|
||||
config.max_length = 128
|
||||
elif "12-12" in model_name:
|
||||
config.time_stride = 12
|
||||
config.frequency_stride = 12
|
||||
elif "14-14" in model_name:
|
||||
config.time_stride = 14
|
||||
config.frequency_stride = 14
|
||||
elif "16-16" in model_name:
|
||||
config.time_stride = 16
|
||||
config.frequency_stride = 16
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
if "speech-commands" in model_name:
|
||||
config.num_labels = 35
|
||||
filename = "speech-commands-v2-id2label.json"
|
||||
else:
|
||||
config.num_labels = 527
|
||||
filename = "audioset-id2label.json"
|
||||
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "module.v" in name:
|
||||
name = name.replace("module.v", "audio_spectrogram_transformer")
|
||||
if "cls_token" in name:
|
||||
name = name.replace("cls_token", "embeddings.cls_token")
|
||||
if "dist_token" in name:
|
||||
name = name.replace("dist_token", "embeddings.distillation_token")
|
||||
if "pos_embed" in name:
|
||||
name = name.replace("pos_embed", "embeddings.position_embeddings")
|
||||
if "patch_embed.proj" in name:
|
||||
name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
|
||||
# transformer blocks
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "encoder.layer")
|
||||
if "attn.proj" in name:
|
||||
name = name.replace("attn.proj", "attention.output.dense")
|
||||
if "attn" in name:
|
||||
name = name.replace("attn", "attention.self")
|
||||
if "norm1" in name:
|
||||
name = name.replace("norm1", "layernorm_before")
|
||||
if "norm2" in name:
|
||||
name = name.replace("norm2", "layernorm_after")
|
||||
if "mlp.fc1" in name:
|
||||
name = name.replace("mlp.fc1", "intermediate.dense")
|
||||
if "mlp.fc2" in name:
|
||||
name = name.replace("mlp.fc2", "output.dense")
|
||||
# final layernorm
|
||||
if "audio_spectrogram_transformer.norm" in name:
|
||||
name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
|
||||
# classifier head
|
||||
if "module.mlp_head.0" in name:
|
||||
name = name.replace("module.mlp_head.0", "classifier.layernorm")
|
||||
if "module.mlp_head.1" in name:
|
||||
name = name.replace("module.mlp_head.1", "classifier.dense")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if "qkv" in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.hidden_size
|
||||
if "weight" in key:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
|
||||
] = val[:dim, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
|
||||
] = val[dim : dim * 2, :]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
|
||||
] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
|
||||
] = val[:dim]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
|
||||
] = val[dim : dim * 2]
|
||||
orig_state_dict[
|
||||
f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
|
||||
] = val[-dim:]
|
||||
else:
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def remove_keys(state_dict):
|
||||
ignore_keys = [
|
||||
"module.v.head.weight",
|
||||
"module.v.head.bias",
|
||||
"module.v.head_dist.weight",
|
||||
"module.v.head_dist.bias",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
|
||||
"""
|
||||
config = get_audio_spectrogram_transformer_config(model_name)
|
||||
|
||||
model_name_to_url = {
|
||||
"ast-finetuned-audioset-10-10-0.4593": (
|
||||
"https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.450": (
|
||||
"https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448": (
|
||||
"https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-10-10-0.448-v2": (
|
||||
"https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-12-12-0.447": (
|
||||
"https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-14-14-0.443": (
|
||||
"https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-audioset-16-16-0.442": (
|
||||
"https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
|
||||
),
|
||||
"ast-finetuned-speech-commands-v2": (
|
||||
"https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
|
||||
),
|
||||
}
|
||||
|
||||
# load original state_dict
|
||||
checkpoint_url = model_name_to_url[model_name]
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove some keys
|
||||
remove_keys(state_dict)
|
||||
# rename some keys
|
||||
new_state_dict = convert_state_dict(state_dict, config)
|
||||
|
||||
# load 🤗 model
|
||||
model = ASTForAudioClassification(config)
|
||||
model.eval()
|
||||
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify outputs on dummy input
|
||||
# source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
|
||||
mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
|
||||
std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
|
||||
max_length = 1024 if "speech-commands" not in model_name else 128
|
||||
feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)
|
||||
|
||||
if "speech-commands" in model_name:
|
||||
# TODO: Convert dataset to Parquet
|
||||
dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
|
||||
waveform = dataset[0]["audio"]["array"]
|
||||
else:
|
||||
filepath = hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint",
|
||||
filename="sample_audio.flac",
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
waveform, _ = torchaudio.load(filepath)
|
||||
waveform = waveform.squeeze().numpy()
|
||||
|
||||
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
if model_name == "ast-finetuned-audioset-10-10-0.4593":
|
||||
expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.450":
|
||||
expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448":
|
||||
expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
|
||||
elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
|
||||
expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
|
||||
elif model_name == "ast-finetuned-audioset-12-12-0.447":
|
||||
expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
|
||||
elif model_name == "ast-finetuned-audioset-14-14-0.443":
|
||||
expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
|
||||
elif model_name == "ast-finetuned-audioset-16-16-0.442":
|
||||
expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
|
||||
elif model_name == "ast-finetuned-speech-commands-v2":
|
||||
expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
|
||||
else:
|
||||
raise ValueError("Unknown model name")
|
||||
if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
|
||||
raise ValueError("Logits don't match")
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and feature extractor to the hub...")
|
||||
model.push_to_hub(f"MIT/{model_name}")
|
||||
feature_extractor.push_to_hub(f"MIT/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="ast-finetuned-audioset-10-10-0.4593",
|
||||
type=str,
|
||||
help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -493,7 +493,7 @@ class AutoImageProcessor:
|
||||
image_processor_auto_map = config.auto_map["AutoImageProcessor"]
|
||||
|
||||
image_processor_class = None
|
||||
# TODO: @yoni, change logic in v4.50 (when use_fast set to True by default)
|
||||
# TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
|
||||
if image_processor_type is not None:
|
||||
# if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor.
|
||||
if use_fast is None:
|
||||
@ -501,7 +501,7 @@ class AutoImageProcessor:
|
||||
if not use_fast:
|
||||
logger.warning_once(
|
||||
"Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
|
||||
"`use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. "
|
||||
"`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
|
||||
"This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
|
||||
)
|
||||
# Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version.
|
||||
|
@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma", "GemmaForCausalLM"),
|
||||
("gemma2", "Gemma2ForCausalLM"),
|
||||
("gemma3", "Gemma3ForCausalLM"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("gemma3_text", "Gemma3ForCausalLM"),
|
||||
("git", "GitForCausalLM"),
|
||||
("glm", "GlmForCausalLM"),
|
||||
@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||
"""
|
||||
return config.get_text_config(decoder=True)
|
||||
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
text_config_names = []
|
||||
for text_config_name in possible_text_config_names:
|
||||
if hasattr(config, text_config_name):
|
||||
text_config_names += [text_config_name]
|
||||
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||
warnings.warn(
|
||||
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||
FutureWarning,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
@ -1,273 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from os import path
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
|
||||
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
|
||||
state_dict = {}
|
||||
|
||||
for orig_k, param in original_sd.items():
|
||||
k = orig_k.replace("backbone", "model")
|
||||
|
||||
# for embeddings
|
||||
k = k.replace("embedding", "embed_tokens")
|
||||
|
||||
# for mixer
|
||||
k = k.replace("mixer", "mamba")
|
||||
|
||||
# for final layernorm
|
||||
k = k.replace("norm_f", "final_layernorm")
|
||||
|
||||
# for block layernorm
|
||||
k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
|
||||
k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)
|
||||
|
||||
# for mlp
|
||||
k = k.replace("mlp.fc2", "feed_forward.down_proj")
|
||||
|
||||
if "mlp.fc1" in k:
|
||||
param, param2 = torch.chunk(param, 2, dim=0)
|
||||
k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
|
||||
state_dict[k2] = param2
|
||||
k = k.replace("mlp.fc1", "feed_forward.up_proj")
|
||||
|
||||
if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
|
||||
"out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
|
||||
):
|
||||
# then this must be a mamba
|
||||
pass
|
||||
else:
|
||||
# for attn
|
||||
# - because mixer was replaced to mamba above
|
||||
k = k.replace("mamba.out_proj", "self_attn.o_proj")
|
||||
if "mamba.in_proj" in k:
|
||||
m, n = param.shape
|
||||
d = (m - n) // 2
|
||||
param, param2, param3 = torch.split(param, [n, d, d], dim=0)
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
|
||||
state_dict[k2] = param2
|
||||
k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
|
||||
state_dict[k2] = param3
|
||||
k = k.replace("mamba.in_proj", "self_attn.q_proj")
|
||||
|
||||
state_dict[k] = param
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_ssm_config_to_hf_config(
|
||||
config_ssm: Dict,
|
||||
**kwargs,
|
||||
) -> BambaConfig:
|
||||
"""Convert a config from mamba_ssm to a BambaConfig from here."""
|
||||
hf_config: BambaConfig = BambaConfig(**kwargs)
|
||||
|
||||
hf_config.architectures = ["BambaForCausalLM"]
|
||||
|
||||
# Set important values from config and recalculate other resulting entries
|
||||
hf_config.hidden_size = config_ssm["d_model"]
|
||||
hf_config.intermediate_size = config_ssm["d_intermediate"]
|
||||
hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
|
||||
hf_config.num_hidden_layers = config_ssm["n_layer"]
|
||||
hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
|
||||
|
||||
# currently this script assumes config_ssm belongs to v2
|
||||
if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
|
||||
raise ValueError("Conversion script only supports Mamba2")
|
||||
|
||||
# Set attention values
|
||||
attn_cfg = config_ssm.get("attn_cfg")
|
||||
if attn_cfg:
|
||||
assert attn_cfg["causal"], "Only support non-causal attention."
|
||||
assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
|
||||
assert not attn_cfg["out_proj_bias"], "Only support no out bias."
|
||||
hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
|
||||
hf_config.num_attention_heads = attn_cfg["num_heads"]
|
||||
hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]
|
||||
|
||||
attention_layer_indices = config_ssm.get("attn_layer_idx")
|
||||
if attention_layer_indices:
|
||||
hf_config.attn_layer_indices = attention_layer_indices
|
||||
|
||||
# Padded vocab size, mostly of 16 but 32 is also very common in different models
|
||||
vocab_size = config_ssm["vocab_size"]
|
||||
pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
|
||||
if (vocab_size % pad_vocab_size_multiple) != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
|
||||
hf_config.vocab_size = vocab_size
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def save_single_safetensor(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
):
|
||||
save_file(
|
||||
state_dict,
|
||||
os.path.join(save_directory, SAFE_WEIGHTS_NAME),
|
||||
metadata,
|
||||
)
|
||||
|
||||
|
||||
def save_sharded_safetensors(
|
||||
state_dict: Dict,
|
||||
save_directory: str,
|
||||
metadata: Dict,
|
||||
max_shard_size: Union[int, str] = "5GB",
|
||||
):
|
||||
filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
|
||||
)
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
# Save the index
|
||||
with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
|
||||
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
mamba_ssm_checkpoint_path: str,
|
||||
precision: str,
|
||||
output_dir: str,
|
||||
tokenizer_path: str = None,
|
||||
save_model: Union[bool, str] = True,
|
||||
) -> None:
|
||||
# load tokenizer if provided, this will be used to set the
|
||||
# token_ids in the config file
|
||||
token_ids = {}
|
||||
if tokenizer_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
for key in [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
]:
|
||||
id = getattr(tokenizer, key, None)
|
||||
if id:
|
||||
token_ids[key] = id
|
||||
|
||||
# there are some configs unsettable by mamba_ssn config, so
|
||||
# if there are changes from the defaults, have to pass them into
|
||||
# the function
|
||||
unsettables = {
|
||||
"mamba_d_head": 64,
|
||||
"mamba_d_state": 128,
|
||||
"mamba_n_groups": 1,
|
||||
"rms_norm_eps": 1e-5,
|
||||
}
|
||||
|
||||
# Load and save config based on name
|
||||
config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
|
||||
with open(config_path, "r", encoding="utf-8") as json_file:
|
||||
config = json.load(json_file)
|
||||
|
||||
# convert the config
|
||||
hf_config = convert_ssm_config_to_hf_config(
|
||||
config_ssm=config,
|
||||
**token_ids,
|
||||
**unsettables,
|
||||
)
|
||||
hf_config.save_pretrained(output_dir)
|
||||
|
||||
# Load state dict of the original model and transfer to hf model
|
||||
state_dict = torch.load(
|
||||
path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
# FIXME: allow other parameters to pass in
|
||||
state_dict = convert_state_dict_from_mamba_ssm(state_dict)
|
||||
|
||||
# Save new model to pytorch_dump_path
|
||||
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
|
||||
|
||||
save_file_fn = None
|
||||
if isinstance(save_model, bool) and save_model:
|
||||
save_file_fn = save_single_safetensor
|
||||
elif isinstance(save_model, str) and save_model == "sharded":
|
||||
save_file_fn = save_sharded_safetensors
|
||||
|
||||
if save_file_fn:
|
||||
save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
default="fp16",
|
||||
const="fp16",
|
||||
required=True,
|
||||
choices=("fp32", "fp16", "bf16"),
|
||||
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tokenizer_model_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to a the tokenizer file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
|
||||
args.mamba2_checkpoint_directory,
|
||||
args.precision,
|
||||
args.output_dir,
|
||||
)
|
@ -1,263 +0,0 @@
|
||||
"""Convert Bark checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from bark.generation import _load_model as _bark_load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import EncodecConfig, EncodecModel, set_seed
|
||||
from transformers.models.bark.configuration_bark import (
|
||||
BarkCoarseConfig,
|
||||
BarkConfig,
|
||||
BarkFineConfig,
|
||||
BarkSemanticConfig,
|
||||
)
|
||||
from transformers.models.bark.generation_configuration_bark import (
|
||||
BarkCoarseGenerationConfig,
|
||||
BarkFineGenerationConfig,
|
||||
BarkGenerationConfig,
|
||||
BarkSemanticGenerationConfig,
|
||||
)
|
||||
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
set_seed(770)
|
||||
|
||||
|
||||
new_layer_name_dict = {
|
||||
"c_attn": "att_proj",
|
||||
"c_proj": "out_proj",
|
||||
"c_fc": "in_proj",
|
||||
"transformer.": "",
|
||||
"h.": "layers.",
|
||||
"ln_1": "layernorm_1",
|
||||
"ln_2": "layernorm_2",
|
||||
"ln_f": "layernorm_final",
|
||||
"wpe": "position_embeds_layer",
|
||||
"wte": "input_embeds_layer",
|
||||
}
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text.pt",
|
||||
},
|
||||
"coarse_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse.pt",
|
||||
},
|
||||
"fine_small": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine.pt",
|
||||
},
|
||||
"text": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "text_2.pt",
|
||||
},
|
||||
"coarse": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "coarse_2.pt",
|
||||
},
|
||||
"fine": {
|
||||
"repo_id": "suno/bark",
|
||||
"file_name": "fine_2.pt",
|
||||
},
|
||||
}
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
key = model_type
|
||||
if use_small:
|
||||
key += "_small"
|
||||
return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
|
||||
|
||||
|
||||
def _download(from_hf_path, file_name):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if model_type == "text":
|
||||
ModelClass = BarkSemanticModel
|
||||
ConfigClass = BarkSemanticConfig
|
||||
GenerationConfigClass = BarkSemanticGenerationConfig
|
||||
elif model_type == "coarse":
|
||||
ModelClass = BarkCoarseModel
|
||||
ConfigClass = BarkCoarseConfig
|
||||
GenerationConfigClass = BarkCoarseGenerationConfig
|
||||
elif model_type == "fine":
|
||||
ModelClass = BarkFineModel
|
||||
ConfigClass = BarkFineConfig
|
||||
GenerationConfigClass = BarkFineGenerationConfig
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["repo_id"], model_info["file_name"])
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
if "input_vocab_size" not in model_args:
|
||||
model_args["input_vocab_size"] = model_args["vocab_size"]
|
||||
model_args["output_vocab_size"] = model_args["vocab_size"]
|
||||
del model_args["vocab_size"]
|
||||
|
||||
# convert Bark model arguments to HF Bark model arguments
|
||||
model_args["num_heads"] = model_args.pop("n_head")
|
||||
model_args["hidden_size"] = model_args.pop("n_embd")
|
||||
model_args["num_layers"] = model_args.pop("n_layer")
|
||||
|
||||
model_config = ConfigClass(**checkpoint["model_args"])
|
||||
model = ModelClass(config=model_config)
|
||||
model_generation_config = GenerationConfigClass()
|
||||
|
||||
model.generation_config = model_generation_config
|
||||
state_dict = checkpoint["model"]
|
||||
# fixup checkpoint
|
||||
unwanted_prefix = "_orig_mod."
|
||||
for k, v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
# replace part of the key with corresponding layer name in HF implementation
|
||||
new_k = k[len(unwanted_prefix) :]
|
||||
for old_layer_name in new_layer_name_dict:
|
||||
new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
|
||||
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
|
||||
extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
|
||||
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
||||
missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
|
||||
if len(extra_keys) != 0:
|
||||
raise ValueError(f"extra keys found: {extra_keys}")
|
||||
if len(missing_keys) != 0:
|
||||
raise ValueError(f"missing keys: {missing_keys}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
n_params = model.num_parameters(exclude_embeddings=True)
|
||||
val_loss = checkpoint["best_val_loss"].item()
|
||||
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
|
||||
model.eval()
|
||||
model.to(device)
|
||||
del checkpoint, state_dict
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
|
||||
device = "cpu" # do conversion on cpu
|
||||
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)
|
||||
|
||||
# load bark initial model
|
||||
bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)
|
||||
|
||||
if model_type == "text":
|
||||
bark_model = bark_model["model"]
|
||||
|
||||
if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
|
||||
raise ValueError("initial and new models don't have the same number of parameters")
|
||||
|
||||
# check if same output as the bark model
|
||||
batch_size = 5
|
||||
sequence_length = 10
|
||||
|
||||
if model_type in ["text", "coarse"]:
|
||||
vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
|
||||
output_old_model = bark_model(vec)[0]
|
||||
|
||||
output_new_model_total = model(vec)
|
||||
|
||||
# take last logits
|
||||
output_new_model = output_new_model_total.logits[:, [-1], :]
|
||||
|
||||
else:
|
||||
prediction_codeboook_channel = 3
|
||||
n_codes_total = 8
|
||||
vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)
|
||||
|
||||
output_new_model_total = model(prediction_codeboook_channel, vec)
|
||||
output_old_model = bark_model(prediction_codeboook_channel, vec)
|
||||
|
||||
output_new_model = output_new_model_total.logits
|
||||
|
||||
# output difference should come from the difference of self-attention implementation design
|
||||
if output_new_model.shape != output_old_model.shape:
|
||||
raise ValueError("initial and new outputs don't have the same shape")
|
||||
if (output_new_model - output_old_model).abs().max().item() > 1e-3:
|
||||
raise ValueError("initial and new outputs are not equal")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
def load_whole_bark_model(
|
||||
semantic_path,
|
||||
coarse_path,
|
||||
fine_path,
|
||||
append_text,
|
||||
hub_path,
|
||||
folder_path,
|
||||
):
|
||||
pytorch_dump_folder_path = os.path.join(folder_path, append_text)
|
||||
|
||||
semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
|
||||
coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
|
||||
fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
|
||||
codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
semantic = BarkSemanticModel.from_pretrained(semantic_path)
|
||||
coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
|
||||
fineAcoustic = BarkFineModel.from_pretrained(fine_path)
|
||||
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")
|
||||
|
||||
bark_config = BarkConfig.from_sub_model_configs(
|
||||
semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
|
||||
)
|
||||
|
||||
bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
|
||||
semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
|
||||
)
|
||||
|
||||
bark = BarkModel(bark_config)
|
||||
|
||||
bark.semantic = semantic
|
||||
bark.coarse_acoustics = coarseAcoustic
|
||||
bark.fine_acoustics = fineAcoustic
|
||||
bark.codec_model = codec
|
||||
|
||||
bark.generation_config = bark_generation_config
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
|
||||
parser.add_argument("model_type", type=str, help="text, coarse or fine.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BART checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
raise Exception("requires fairseq >= 0.9.0")
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||
|
||||
mnli_rename_keys = [
|
||||
("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
|
||||
("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
|
||||
("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
|
||||
("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
|
||||
]
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"_float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def load_xsum_checkpoint(checkpoint_path):
|
||||
"""Checkpoint path should end in model.pt"""
|
||||
sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
|
||||
hub_interface.model.load_state_dict(sd["model"])
|
||||
return hub_interface
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
if not os.path.exists(checkpoint_path):
|
||||
bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
|
||||
else:
|
||||
bart = load_xsum_checkpoint(checkpoint_path)
|
||||
|
||||
bart.model.upgrade_state_dict(bart.model.state_dict())
|
||||
if hf_checkpoint_name is None:
|
||||
hf_checkpoint_name = checkpoint_path.replace(".", "-")
|
||||
config = BartConfig.from_pretrained(hf_checkpoint_name)
|
||||
tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
|
||||
tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
|
||||
if not torch.eq(tokens, tokens2).all():
|
||||
raise ValueError(
|
||||
f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
|
||||
)
|
||||
|
||||
if checkpoint_path == "bart.large.mnli":
|
||||
state_dict = bart.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
|
||||
for src, dest in mnli_rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
model = BartForSequenceClassification(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
fairseq_output = bart.predict("mnli", tokens, return_logits=True)
|
||||
new_model_outputs = model(tokens)[0] # logits
|
||||
else: # no classification heads to worry about
|
||||
state_dict = bart.model.state_dict()
|
||||
remove_ignore_keys_(state_dict)
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
fairseq_output = bart.extract_features(tokens)
|
||||
if hf_checkpoint_name == "facebook/bart-large":
|
||||
model = BartModel(config).eval()
|
||||
model.load_state_dict(state_dict)
|
||||
new_model_outputs = model(tokens).model[0]
|
||||
else:
|
||||
model = BartForConditionalGeneration(config).eval() # an existing summarization ckpt
|
||||
model.model.load_state_dict(state_dict)
|
||||
if hasattr(model, "lm_head"):
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
new_model_outputs = model.model(tokens)[0]
|
||||
|
||||
# Check results
|
||||
if fairseq_output.shape != new_model_outputs.shape:
|
||||
raise ValueError(
|
||||
f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
|
||||
)
|
||||
if (fairseq_output != new_model_outputs).any().item():
|
||||
raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
|
||||
)
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
|
@ -1,373 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BEiT checkpoints from the unilm repository."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
BeitConfig,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitImageProcessor,
|
||||
)
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + shared relative position bias + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
(
|
||||
"rel_pos_bias.relative_position_bias_table",
|
||||
"beit.encoder.relative_position_bias.relative_position_bias_table",
|
||||
),
|
||||
(
|
||||
"rel_pos_bias.relative_position_index",
|
||||
"beit.encoder.relative_position_bias.relative_position_index",
|
||||
),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
elif is_semantic:
|
||||
# semantic segmentation classification heads
|
||||
rename_keys.extend(
|
||||
[
|
||||
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
|
||||
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
|
||||
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
|
||||
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
# relative_position bias table + index
|
||||
if not has_lm_head:
|
||||
# each layer has its own relative position bias
|
||||
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
|
||||
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
|
||||
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
|
||||
] = table
|
||||
state_dict[
|
||||
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
|
||||
] = index
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
config = BeitConfig()
|
||||
has_lm_head = False
|
||||
is_semantic = False
|
||||
repo_id = "huggingface/label-files"
|
||||
# set config parameters based on URL
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
# masked image modeling
|
||||
config.use_shared_relative_position_bias = True
|
||||
config.use_mask_token = True
|
||||
has_lm_head = True
|
||||
elif checkpoint_url[-9:-4] == "ft22k":
|
||||
# intermediate fine-tuning on ImageNet-22k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
elif checkpoint_url[-8:-4] == "to1k":
|
||||
# fine-tuning on ImageNet-1k
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
if "384" in checkpoint_url:
|
||||
config.image_size = 384
|
||||
if "512" in checkpoint_url:
|
||||
config.image_size = 512
|
||||
elif "ade20k" in checkpoint_url:
|
||||
# fine-tuning
|
||||
config.use_relative_position_bias = True
|
||||
config.num_labels = 150
|
||||
filename = "ade20k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.image_size = 640
|
||||
is_semantic = True
|
||||
else:
|
||||
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
|
||||
|
||||
# size of the architecture
|
||||
if "base" in checkpoint_url:
|
||||
pass
|
||||
elif "large" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
if "ade20k" in checkpoint_url:
|
||||
config.image_size = 640
|
||||
config.out_indices = [7, 11, 15, 23]
|
||||
else:
|
||||
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
|
||||
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
|
||||
if is_semantic:
|
||||
# add prefix to decoder keys
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("backbone.fpn"):
|
||||
key = key.replace("backbone.fpn", "fpn")
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
if checkpoint_url[-9:-4] == "pt22k":
|
||||
model = BeitForMaskedImageModeling(config)
|
||||
elif "ade20k" in checkpoint_url:
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
else:
|
||||
model = BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
|
||||
)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = torch.Size([1, 1000])
|
||||
if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
|
||||
expected_shape = torch.Size([1, 196, 8192])
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
|
||||
expected_class_idx = 2397
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
|
||||
expected_shape = torch.Size([1, 21841])
|
||||
expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
|
||||
expected_class_idx = 2396
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
|
||||
expected_class_idx = 285
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
|
||||
expected_class_idx = 281
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
|
||||
expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
|
||||
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
|
||||
expected_class_idx = 761
|
||||
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
]
|
||||
)
|
||||
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
|
||||
expected_shape = (1, 150, 160, 160)
|
||||
expected_logits = torch.tensor(
|
||||
[
|
||||
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
|
||||
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
|
||||
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -1,246 +0,0 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
|
||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
|
||||
|
||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
|
||||
weight names to the original names, so the model can be imported with Huggingface/transformer.
|
||||
|
||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
||||
|
||||
Note: This script is only working with an older version of the TensorFlow models repository (<= v2.3.0).
|
||||
Models trained with never versions are not compatible with this script.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
layer_depth = []
|
||||
for full_name, shape in init_vars:
|
||||
# logger.info(f"Loading TF weight {name} with shape {shape}")
|
||||
name = full_name.split("/")
|
||||
if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
|
||||
logger.info(f"Skipping non-model layer {full_name}")
|
||||
continue
|
||||
if "optimizer" in full_name:
|
||||
logger.info(f"Skipping optimization layer {full_name}")
|
||||
continue
|
||||
if name[0] == "model":
|
||||
# ignore initial 'model'
|
||||
name = name[1:]
|
||||
# figure out how many levels deep the name is
|
||||
depth = 0
|
||||
for _name in name:
|
||||
if _name.startswith("layer_with_weights"):
|
||||
depth += 1
|
||||
else:
|
||||
break
|
||||
layer_depth.append(depth)
|
||||
# read data
|
||||
array = tf.train.load_variable(tf_path, full_name)
|
||||
names.append("/".join(name))
|
||||
arrays.append(array)
|
||||
logger.info(f"Read a total of {len(arrays):,} layers")
|
||||
|
||||
# Sanity check
|
||||
if len(set(layer_depth)) != 1:
|
||||
raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
|
||||
layer_depth = list(set(layer_depth))[0]
|
||||
if layer_depth != 1:
|
||||
raise ValueError(
|
||||
"The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
|
||||
" heads."
|
||||
)
|
||||
|
||||
# convert layers
|
||||
logger.info("Converting weights...")
|
||||
for full_name, array in zip(names, arrays):
|
||||
name = full_name.split("/")
|
||||
pointer = model
|
||||
trace = []
|
||||
for i, m_name in enumerate(name):
|
||||
if m_name == ".ATTRIBUTES":
|
||||
# variable names end with .ATTRIBUTES/VARIABLE_VALUE
|
||||
break
|
||||
if m_name.startswith("layer_with_weights"):
|
||||
layer_num = int(m_name.split("-")[-1])
|
||||
if layer_num <= 2:
|
||||
# embedding layers
|
||||
# layer_num 0: word_embeddings
|
||||
# layer_num 1: position_embeddings
|
||||
# layer_num 2: token_type_embeddings
|
||||
continue
|
||||
elif layer_num == 3:
|
||||
# embedding LayerNorm
|
||||
trace.extend(["embeddings", "LayerNorm"])
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
|
||||
# encoder layers
|
||||
trace.extend(["encoder", "layer", str(layer_num - 4)])
|
||||
pointer = getattr(pointer, "encoder")
|
||||
pointer = getattr(pointer, "layer")
|
||||
pointer = pointer[layer_num - 4]
|
||||
elif layer_num == config.num_hidden_layers + 4:
|
||||
# pooler layer
|
||||
trace.extend(["pooler", "dense"])
|
||||
pointer = getattr(pointer, "pooler")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "embeddings":
|
||||
trace.append("embeddings")
|
||||
pointer = getattr(pointer, "embeddings")
|
||||
if layer_num == 0:
|
||||
trace.append("word_embeddings")
|
||||
pointer = getattr(pointer, "word_embeddings")
|
||||
elif layer_num == 1:
|
||||
trace.append("position_embeddings")
|
||||
pointer = getattr(pointer, "position_embeddings")
|
||||
elif layer_num == 2:
|
||||
trace.append("token_type_embeddings")
|
||||
pointer = getattr(pointer, "token_type_embeddings")
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding layer with name {full_name}")
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
elif m_name == "_attention_layer":
|
||||
# self-attention layer
|
||||
trace.extend(["attention", "self"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "self")
|
||||
elif m_name == "_attention_layer_norm":
|
||||
# output attention norm
|
||||
trace.extend(["attention", "output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_attention_output_dense":
|
||||
# output attention dense
|
||||
trace.extend(["attention", "output", "dense"])
|
||||
pointer = getattr(pointer, "attention")
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_dense":
|
||||
# output dense
|
||||
trace.extend(["output", "dense"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output dense
|
||||
trace.extend(["output", "LayerNorm"])
|
||||
pointer = getattr(pointer, "output")
|
||||
pointer = getattr(pointer, "LayerNorm")
|
||||
elif m_name == "_key_dense":
|
||||
# attention key
|
||||
trace.append("key")
|
||||
pointer = getattr(pointer, "key")
|
||||
elif m_name == "_query_dense":
|
||||
# attention query
|
||||
trace.append("query")
|
||||
pointer = getattr(pointer, "query")
|
||||
elif m_name == "_value_dense":
|
||||
# attention value
|
||||
trace.append("value")
|
||||
pointer = getattr(pointer, "value")
|
||||
elif m_name == "_intermediate_dense":
|
||||
# attention intermediate dense
|
||||
trace.extend(["intermediate", "dense"])
|
||||
pointer = getattr(pointer, "intermediate")
|
||||
pointer = getattr(pointer, "dense")
|
||||
elif m_name == "_output_layer_norm":
|
||||
# output layer norm
|
||||
trace.append("output")
|
||||
pointer = getattr(pointer, "output")
|
||||
# weights & biases
|
||||
elif m_name in ["bias", "beta"]:
|
||||
trace.append("bias")
|
||||
pointer = getattr(pointer, "bias")
|
||||
elif m_name in ["kernel", "gamma"]:
|
||||
trace.append("weight")
|
||||
pointer = getattr(pointer, "weight")
|
||||
else:
|
||||
logger.warning(f"Ignored {m_name}")
|
||||
# for certain layers reshape is necessary
|
||||
trace = ".".join(trace)
|
||||
if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
|
||||
r"(\S+)\.attention\.output\.dense\.weight", trace
|
||||
):
|
||||
array = array.reshape(pointer.data.shape)
|
||||
if "kernel" in full_name:
|
||||
array = array.transpose()
|
||||
if pointer.shape == array.shape:
|
||||
pointer.data = torch.from_numpy(array)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
|
||||
f" {array.shape}"
|
||||
)
|
||||
logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
|
||||
return model
|
||||
|
||||
|
||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
|
||||
# Instantiate model
|
||||
logger.info(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertModel(config)
|
||||
|
||||
# Load weights from checkpoint
|
||||
logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
|
||||
load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
|
||||
|
||||
# Save pytorch-model
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model (must include filename).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,62 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,112 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
"""
|
||||
Args:
|
||||
model: BertModel Pytorch model instance to be converted
|
||||
ckpt_dir: Tensorflow model directory
|
||||
model_name: model name
|
||||
|
||||
Currently supported HF models:
|
||||
|
||||
- Y BertModel
|
||||
- N BertForMaskedLM
|
||||
- N BertForPreTraining
|
||||
- N BertForMultipleChoice
|
||||
- N BertForNextSentencePrediction
|
||||
- N BertForSequenceClassification
|
||||
- N BertForQuestionAnswering
|
||||
"""
|
||||
|
||||
tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
|
||||
|
||||
var_map = (
|
||||
("layer.", "layer_"),
|
||||
("word_embeddings.weight", "word_embeddings"),
|
||||
("position_embeddings.weight", "position_embeddings"),
|
||||
("token_type_embeddings.weight", "token_type_embeddings"),
|
||||
(".", "/"),
|
||||
("LayerNorm/weight", "LayerNorm/gamma"),
|
||||
("LayerNorm/bias", "LayerNorm/beta"),
|
||||
("weight", "kernel"),
|
||||
)
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def to_tf_var_name(name: str):
|
||||
for patt, repl in iter(var_map):
|
||||
name = name.replace(patt, repl)
|
||||
return f"bert/{name}"
|
||||
|
||||
def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
|
||||
tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
|
||||
tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
|
||||
session.run(tf.variables_initializer([tf_var]))
|
||||
session.run(tf_var)
|
||||
return tf_var
|
||||
|
||||
tf.reset_default_graph()
|
||||
with tf.Session() as session:
|
||||
for var_name in state_dict:
|
||||
tf_name = to_tf_var_name(var_name)
|
||||
torch_tensor = state_dict[var_name].numpy()
|
||||
if any(x in var_name for x in tensors_to_transpose):
|
||||
torch_tensor = torch_tensor.T
|
||||
tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
|
||||
tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
|
||||
tf_weight = session.run(tf_var)
|
||||
print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
|
||||
|
||||
saver = tf.train.Saver(tf.trainable_variables())
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
|
||||
parser.add_argument(
|
||||
"--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
|
||||
)
|
||||
parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,188 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
|
||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
|
||||
|
||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForMaskedLM
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertPooler,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
|
||||
def get_masked_lm_array(name: str):
|
||||
full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_array(name: str):
|
||||
full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_layer_array(layer_index: int, name: str):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
def get_encoder_attention_layer_array(layer_index: int, name: str, orginal_shape):
|
||||
full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
|
||||
array = tf.train.load_variable(tf_checkpoint_path, full_name)
|
||||
array = array.reshape(orginal_shape)
|
||||
|
||||
if "kernel" in name:
|
||||
array = array.transpose()
|
||||
|
||||
return torch.from_numpy(array)
|
||||
|
||||
print(f"Loading model based on config from {config_path}...")
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
model = BertForMaskedLM(config)
|
||||
|
||||
# Layers
|
||||
for layer_index in range(0, config.num_hidden_layers):
|
||||
layer: BertLayer = model.bert.encoder.layer[layer_index]
|
||||
|
||||
# Self-attention
|
||||
self_attn: BertSelfAttention = layer.attention.self
|
||||
|
||||
self_attn.query.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
|
||||
)
|
||||
self_attn.query.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
|
||||
)
|
||||
self_attn.key.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
|
||||
)
|
||||
self_attn.key.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
|
||||
)
|
||||
self_attn.value.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
|
||||
)
|
||||
self_attn.value.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
|
||||
)
|
||||
|
||||
# Self-attention Output
|
||||
self_output: BertSelfOutput = layer.attention.output
|
||||
|
||||
self_output.dense.weight.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
|
||||
)
|
||||
self_output.dense.bias.data = get_encoder_attention_layer_array(
|
||||
layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
|
||||
)
|
||||
|
||||
self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
|
||||
self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
|
||||
|
||||
# Intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
|
||||
intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
|
||||
intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
|
||||
|
||||
# Output
|
||||
bert_output: BertOutput = layer.output
|
||||
|
||||
bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
|
||||
bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
|
||||
|
||||
bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
|
||||
bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
|
||||
|
||||
# Embeddings
|
||||
model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
|
||||
model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
|
||||
model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
|
||||
model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
|
||||
|
||||
# LM Head
|
||||
lm_head = model.cls.predictions.transform
|
||||
|
||||
lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
|
||||
lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
|
||||
|
||||
lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
|
||||
lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
|
||||
|
||||
model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
|
||||
|
||||
# Pooling
|
||||
model.bert.pooler = BertPooler(config=config)
|
||||
model.bert.pooler.dense.weight.data: BertPooler = get_encoder_array("_pooler_layer/kernel")
|
||||
model.bert.pooler.dense.bias.data: BertPooler = get_encoder_array("_pooler_layer/bias")
|
||||
|
||||
# Export final model
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Integration test - should load without any errors ;)
|
||||
new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
|
||||
print(new_model.eval())
|
||||
|
||||
print("Model conversion was done sucessfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_config_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the BERT model. This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
|
@ -1,69 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigBird checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
|
||||
# Initialise PyTorch model
|
||||
config = BigBirdConfig.from_json_file(big_bird_config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
if is_trivia_qa:
|
||||
model = BigBirdForQuestionAnswering(config)
|
||||
else:
|
||||
model = BigBirdForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big_bird_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained BERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
|
||||
)
|
@ -1,170 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration
|
||||
|
||||
|
||||
INIT_COMMON = [
|
||||
# tf -> hf
|
||||
("/", "."),
|
||||
("layer_", "layers."),
|
||||
("kernel", "weight"),
|
||||
("beta", "bias"),
|
||||
("gamma", "weight"),
|
||||
("pegasus", "model"),
|
||||
]
|
||||
END_COMMON = [
|
||||
(".output.dense", ".fc2"),
|
||||
("intermediate.LayerNorm", "final_layer_norm"),
|
||||
("intermediate.dense", "fc1"),
|
||||
]
|
||||
|
||||
DECODER_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.out_proj"),
|
||||
("attention.self", "self_attn"),
|
||||
("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
|
||||
("attention.encdec_output.dense", "encoder_attn.out_proj"),
|
||||
("attention.encdec", "encoder_attn"),
|
||||
("key", "k_proj"),
|
||||
("value", "v_proj"),
|
||||
("query", "q_proj"),
|
||||
("decoder.LayerNorm", "decoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
REMAINING_PATTERNS = (
|
||||
INIT_COMMON
|
||||
+ [
|
||||
("embeddings.word_embeddings", "shared.weight"),
|
||||
("embeddings.position_embeddings", "embed_positions.weight"),
|
||||
("attention.self.LayerNorm", "self_attn_layer_norm"),
|
||||
("attention.output.dense", "self_attn.output"),
|
||||
("attention.self", "self_attn.self"),
|
||||
("encoder.LayerNorm", "encoder.layernorm_embedding"),
|
||||
]
|
||||
+ END_COMMON
|
||||
)
|
||||
|
||||
KEYS_TO_IGNORE = [
|
||||
"encdec/key/bias",
|
||||
"encdec/query/bias",
|
||||
"encdec/value/bias",
|
||||
"self/key/bias",
|
||||
"self/query/bias",
|
||||
"self/value/bias",
|
||||
"encdec_output/dense/bias",
|
||||
"attention/output/dense/bias",
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k, patterns):
|
||||
for tf_name, hf_name in patterns:
|
||||
k = k.replace(tf_name, hf_name)
|
||||
return k
|
||||
|
||||
|
||||
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
|
||||
cfg = BigBirdPegasusConfig(**config_update)
|
||||
torch_model = BigBirdPegasusForConditionalGeneration(cfg)
|
||||
state_dict = torch_model.state_dict()
|
||||
mapping = {}
|
||||
|
||||
# separating decoder weights
|
||||
decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
|
||||
remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}
|
||||
|
||||
for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = DECODER_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict:
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
|
||||
conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
|
||||
if any(conditions):
|
||||
continue
|
||||
patterns = REMAINING_PATTERNS
|
||||
new_k = rename_state_dict_key(k, patterns)
|
||||
if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
|
||||
raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
|
||||
if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
|
||||
v = v.T
|
||||
mapping[new_k] = torch.from_numpy(v)
|
||||
if k != "pegasus/embeddings/position_embeddings":
|
||||
assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"
|
||||
|
||||
mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
|
||||
mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
|
||||
missing, extra = torch_model.load_state_dict(mapping, strict=False)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if k
|
||||
not in [
|
||||
"final_logits_bias",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
]
|
||||
assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
|
||||
assert extra == [], f"no matches found for the following tf keys {extra}"
|
||||
return torch_model
|
||||
|
||||
|
||||
def get_tf_weights_as_numpy(path) -> Dict:
|
||||
init_vars = tf.train.list_variables(path)
|
||||
tf_weights = {}
|
||||
ignore_name = ["global_step"]
|
||||
for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
|
||||
skip_key = any(pat in name for pat in ignore_name)
|
||||
if skip_key:
|
||||
continue
|
||||
array = tf.train.load_variable(path, name)
|
||||
tf_weights[name] = array
|
||||
return tf_weights
|
||||
|
||||
|
||||
def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
torch_model = convert_bigbird_pegasus(tf_weights, config_update)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
|
||||
parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
config_update = {}
|
||||
convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
|
@ -1,292 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BioGptConfig, BioGptForCausalLM
|
||||
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
json_indent = 2
|
||||
|
||||
|
||||
# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
|
||||
class Dictionary:
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(f, "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
"Duplicate word found when loading Dictionary: '{}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file.".format(word)
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")
|
||||
|
||||
|
||||
def rewrite_dict_keys(d):
|
||||
# (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
|
||||
# e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
|
||||
d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
|
||||
keep_keys = "<s> <pad> </s> <unk>".split()
|
||||
# restore the special tokens
|
||||
for k in keep_keys:
|
||||
del d2[f"{k}</w>"]
|
||||
d2[k] = d[k] # restore
|
||||
return d2
|
||||
|
||||
|
||||
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
|
||||
# prep
|
||||
if not os.path.exists(biogpt_checkpoint_path):
|
||||
raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
print(f"Writing results to {pytorch_dump_folder_path}")
|
||||
|
||||
# handle various types of models
|
||||
|
||||
checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
|
||||
if not os.path.isfile(checkpoint_file):
|
||||
raise ValueError(f"path to the file {checkpoint_file} does not exist!")
|
||||
chkpt = torch.load(checkpoint_file, map_location="cpu")
|
||||
|
||||
args = chkpt["cfg"]["model"]
|
||||
|
||||
# dicts
|
||||
dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
|
||||
if not os.path.isfile(dict_file):
|
||||
raise ValueError(f"path to the file {dict_file} does not exist!")
|
||||
src_dict = Dictionary.load(dict_file)
|
||||
src_vocab = rewrite_dict_keys(src_dict.indices)
|
||||
src_vocab_size = len(src_vocab)
|
||||
src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
|
||||
print(f"Generating {src_vocab_file} of {src_vocab_size} records")
|
||||
with open(src_vocab_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# merges_file (bpecodes)
|
||||
bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
|
||||
if not os.path.isfile(bpecodes_file):
|
||||
raise ValueError(f"path to the file {bpecodes_file} does not exist!")
|
||||
|
||||
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
|
||||
shutil.copyfile(bpecodes_file, merges_file)
|
||||
|
||||
# model config
|
||||
biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")
|
||||
|
||||
model_conf = {
|
||||
"activation_dropout": args["activation_dropout"],
|
||||
"architectures": ["BioGptForCausalLM"],
|
||||
"attention_probs_dropout_prob": args["attention_dropout"],
|
||||
"bos_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": args["activation_fn"],
|
||||
"hidden_dropout_prob": args["dropout"],
|
||||
"hidden_size": args["decoder_embed_dim"],
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": args["decoder_ffn_embed_dim"],
|
||||
"layer_norm_eps": 1e-12,
|
||||
"layerdrop": args["decoder_layerdrop"],
|
||||
"max_position_embeddings": args["max_target_positions"],
|
||||
"model_type": "biogpt",
|
||||
"num_attention_heads": args["decoder_attention_heads"],
|
||||
"num_hidden_layers": args["decoder_layers"],
|
||||
"pad_token_id": 1,
|
||||
"scale_embedding": not args["no_scale_embedding"],
|
||||
"tie_word_embeddings": args["share_decoder_input_output_embed"],
|
||||
"vocab_size": src_vocab_size,
|
||||
}
|
||||
|
||||
# good hparam defaults to start with
|
||||
|
||||
print(f"Generating {biogpt_model_config_file}")
|
||||
with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# tokenizer config
|
||||
biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)
|
||||
|
||||
tokenizer_conf = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1024,
|
||||
"pad_token": "<pad>",
|
||||
"special_tokens_map_file": None,
|
||||
"tokenizer_class": "BioGptTokenizer",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
|
||||
print(f"Generating {biogpt_tokenizer_config_file}")
|
||||
with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))
|
||||
|
||||
# model
|
||||
model_state_dict = chkpt["model"]
|
||||
|
||||
# remove unneeded keys
|
||||
ignore_keys = [
|
||||
"decoder.version",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
model_state_dict.pop(k, None)
|
||||
|
||||
layer_names = list(model_state_dict.keys())
|
||||
for layer_name in layer_names:
|
||||
if layer_name.endswith("output_projection.weight"):
|
||||
model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
|
||||
else:
|
||||
model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)
|
||||
|
||||
config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
|
||||
model_new = BioGptForCausalLM(config)
|
||||
|
||||
# check that it loads ok
|
||||
model_new.load_state_dict(model_state_dict)
|
||||
|
||||
# save
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
print(f"Generating {pytorch_weights_dump_path}")
|
||||
torch.save(model_state_dict, pytorch_weights_dump_path)
|
||||
|
||||
print("Conversion is done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--biogpt_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
|
||||
" bpecodes, etc."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,177 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BiT checkpoints from the timm library."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config
|
||||
from timm.data.transforms_factory import create_transform
|
||||
|
||||
from transformers import BitConfig, BitForImageClassification, BitImageProcessor
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_config(model_name):
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
conv_layer = "std_conv" if "bit" in model_name else False
|
||||
|
||||
# note that when using BiT as backbone for ViT-hybrid checkpoints,
|
||||
# one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
|
||||
# config.conv_layer = "std_conv_same"
|
||||
config = BitConfig(
|
||||
conv_layer=conv_layer,
|
||||
num_labels=1000,
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "stem.conv" in name:
|
||||
name = name.replace("stem.conv", "bit.embedder.convolution")
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "layers")
|
||||
if "head.fc" in name:
|
||||
name = name.replace("head.fc", "classifier.1")
|
||||
if name.startswith("norm"):
|
||||
name = "bit." + name
|
||||
if "bit" not in name and "classifier" not in name:
|
||||
name = "bit.encoder." + name
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BiT structure.
|
||||
"""
|
||||
|
||||
# define default BiT configuration
|
||||
config = get_config(model_name)
|
||||
|
||||
# load original model from timm
|
||||
timm_model = create_model(model_name, pretrained=True)
|
||||
timm_model.eval()
|
||||
|
||||
# load state_dict of original model
|
||||
state_dict = timm_model.state_dict()
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val.squeeze() if "head" in key else val
|
||||
|
||||
# load HuggingFace model
|
||||
model = BitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# create image processor
|
||||
transform = create_transform(**resolve_data_config({}, model=timm_model))
|
||||
timm_transforms = transform.transforms
|
||||
|
||||
pillow_resamplings = {
|
||||
"bilinear": PILImageResampling.BILINEAR,
|
||||
"bicubic": PILImageResampling.BICUBIC,
|
||||
"nearest": PILImageResampling.NEAREST,
|
||||
}
|
||||
|
||||
processor = BitImageProcessor(
|
||||
do_resize=True,
|
||||
size={"shortest_edge": timm_transforms[0].size},
|
||||
resample=pillow_resamplings[timm_transforms[0].interpolation.value],
|
||||
do_center_crop=True,
|
||||
crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
|
||||
do_normalize=True,
|
||||
image_mean=timm_transforms[-1].mean.tolist(),
|
||||
image_std=timm_transforms[-1].std.tolist(),
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
timm_pixel_values = transform(image).unsqueeze(0)
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
# verify pixel values
|
||||
assert torch.allclose(timm_pixel_values, pixel_values)
|
||||
|
||||
# verify logits
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
print("Logits:", logits[0, :3])
|
||||
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
|
||||
timm_logits = timm_model(pixel_values)
|
||||
assert timm_logits.shape == outputs.logits.shape
|
||||
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model {model_name} and processor to the hub")
|
||||
model.push_to_hub(f"ybelkada/{model_name}")
|
||||
processor.push_to_hub(f"ybelkada/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="resnetv2_50x1_bitm",
|
||||
type=str,
|
||||
help="Name of the BiT timm model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,114 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Blenderbot checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
PATTERNS = [
|
||||
["attention", "attn"],
|
||||
["encoder_attention", "encoder_attn"],
|
||||
["q_lin", "q_proj"],
|
||||
["k_lin", "k_proj"],
|
||||
["v_lin", "v_proj"],
|
||||
["out_lin", "out_proj"],
|
||||
["norm_embeddings", "layernorm_embedding"],
|
||||
["position_embeddings", "embed_positions"],
|
||||
["embeddings", "embed_tokens"],
|
||||
["ffn.lin", "fc"],
|
||||
]
|
||||
|
||||
|
||||
def rename_state_dict_key(k):
|
||||
if k == "embeddings.weight":
|
||||
return "shared.weight"
|
||||
|
||||
for parlai_name, hf_name in PATTERNS:
|
||||
k = k.replace(parlai_name, hf_name)
|
||||
|
||||
if k.startswith("encoder"):
|
||||
k = k.replace(".attn", ".self_attn")
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "final_layer_norm")
|
||||
elif k.startswith("decoder"):
|
||||
k = k.replace("norm1", "self_attn_layer_norm")
|
||||
k = k.replace("norm2", "encoder_attn_layer_norm")
|
||||
k = k.replace("norm3", "final_layer_norm")
|
||||
return k
|
||||
|
||||
|
||||
def rename_layernorm_keys(sd):
|
||||
keys = [
|
||||
"model.encoder.layernorm_embedding.weight",
|
||||
"model.encoder.layernorm_embedding.bias",
|
||||
"model.decoder.layernorm_embedding.weight",
|
||||
"model.decoder.layernorm_embedding.bias",
|
||||
]
|
||||
for k in keys:
|
||||
v = sd.pop(k)
|
||||
new_k = k.replace("layernorm_embedding", "layer_norm")
|
||||
assert new_k not in sd
|
||||
sd[new_k] = v
|
||||
|
||||
|
||||
IGNORE_KEYS = ["START"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BERT structure.
|
||||
"""
|
||||
model = torch.load(checkpoint_path, map_location="cpu")
|
||||
sd = model["model"]
|
||||
cfg = BlenderbotConfig.from_json_file(config_json_path)
|
||||
m = BlenderbotForConditionalGeneration(cfg)
|
||||
valid_keys = m.model.state_dict().keys()
|
||||
failures = []
|
||||
mapping = {}
|
||||
for k, v in sd.items():
|
||||
if k in IGNORE_KEYS:
|
||||
continue
|
||||
|
||||
new_k = rename_state_dict_key(k)
|
||||
if new_k not in valid_keys:
|
||||
failures.append([k, new_k])
|
||||
else:
|
||||
mapping[new_k] = v
|
||||
if cfg.normalize_before: # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
|
||||
rename_layernorm_keys(sd)
|
||||
m.model.load_state_dict(mapping, strict=True)
|
||||
m.half()
|
||||
m.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
|
||||
parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
|
||||
parser.add_argument(
|
||||
"--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
|
@ -1,191 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# git clone https://github.com/salesforce/BLIP.git
|
||||
from models.blip import blip_decoder
|
||||
from models.blip_itm import blip_itm
|
||||
from models.blip_vqa import blip_vqa
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from transformers import (
|
||||
BertTokenizer,
|
||||
BlipConfig,
|
||||
BlipForConditionalGeneration,
|
||||
BlipForImageTextRetrieval,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
def load_demo_image(image_size, device):
|
||||
img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
]
|
||||
)
|
||||
image = transform(raw_image).unsqueeze(0).to(device)
|
||||
return image
|
||||
|
||||
|
||||
def rename_key(key):
|
||||
if "visual_encoder" in key:
|
||||
key = re.sub("visual_encoder*", "vision_model.encoder", key)
|
||||
if "blocks" in key:
|
||||
key = re.sub(r"blocks", "layers", key)
|
||||
if "attn" in key:
|
||||
key = re.sub(r"attn", "self_attn", key)
|
||||
if "norm1" in key:
|
||||
key = re.sub(r"norm1", "layer_norm1", key)
|
||||
if "norm2" in key:
|
||||
key = re.sub(r"norm2", "layer_norm2", key)
|
||||
if "encoder.norm" in key:
|
||||
key = re.sub(r"encoder.norm", "post_layernorm", key)
|
||||
if "encoder.patch_embed.proj" in key:
|
||||
key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
|
||||
|
||||
if "encoder.pos_embed" in key:
|
||||
key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
|
||||
if "encoder.cls_token" in key:
|
||||
key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
|
||||
|
||||
if "self_attn" in key:
|
||||
key = re.sub(r"self_attn.proj", "self_attn.projection", key)
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = BlipConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = BlipConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = BlipForConditionalGeneration(config).eval()
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
|
||||
|
||||
pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
modified_state_dict = pt_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_model.load_state_dict(modified_state_dict)
|
||||
|
||||
image_size = 384
|
||||
image = load_demo_image(image_size=image_size, device="cpu")
|
||||
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
|
||||
input_ids = tokenizer(["a picture of"]).input_ids
|
||||
|
||||
out = hf_model.generate(image, input_ids)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
out = hf_model.generate(image)
|
||||
|
||||
assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
# model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
|
||||
model_url = (
|
||||
"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
|
||||
)
|
||||
|
||||
vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
|
||||
vqa_model.eval()
|
||||
|
||||
modified_state_dict = vqa_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_vqa_model = BlipForQuestionAnswering(config)
|
||||
|
||||
hf_vqa_model.load_state_dict(modified_state_dict)
|
||||
|
||||
question = ["How many dogs are in this image?"]
|
||||
question_input_ids = tokenizer(question, return_tensors="pt").input_ids
|
||||
|
||||
answer = hf_vqa_model.generate(question_input_ids, image)
|
||||
print(tokenizer.decode(answer[0]))
|
||||
|
||||
assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
|
||||
|
||||
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"
|
||||
|
||||
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
|
||||
itm_model.eval()
|
||||
|
||||
modified_state_dict = itm_model.state_dict()
|
||||
for key in modified_state_dict.copy():
|
||||
value = modified_state_dict.pop(key)
|
||||
renamed_key = rename_key(key)
|
||||
modified_state_dict[renamed_key] = value
|
||||
|
||||
hf_itm_model = BlipForImageTextRetrieval(config)
|
||||
|
||||
question = ["A picture of a woman with a dog sitting in a beach"]
|
||||
question_input_ids = tokenizer(
|
||||
question,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=35,
|
||||
).input_ids
|
||||
|
||||
hf_itm_model.load_state_dict(modified_state_dict)
|
||||
hf_itm_model.eval()
|
||||
|
||||
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
|
||||
out = hf_itm_model(question_input_ids, image, use_itm_head=False)
|
||||
|
||||
assert out[0].item() == 0.2110687494277954
|
||||
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
|
@ -1,390 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert BLIP-2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
|
||||
# to make sure we can compare both original and HF implementation in float32
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
Blip2Config,
|
||||
Blip2ForConditionalGeneration,
|
||||
Blip2ForImageTextRetrieval,
|
||||
Blip2Processor,
|
||||
Blip2QFormerConfig,
|
||||
Blip2VisionConfig,
|
||||
BlipImageProcessor,
|
||||
OPTConfig,
|
||||
T5Config,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, model_name):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
|
||||
if "itm" in model_name:
|
||||
rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
|
||||
rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
|
||||
rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
|
||||
rename_keys.append(("text_proj.weight", "text_projection.weight"))
|
||||
rename_keys.append(("text_proj.bias", "text_projection.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name, eos_token_id):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "opt-2.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "opt-6.7b" in model_name:
|
||||
text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
|
||||
elif "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "itm" in model_name:
|
||||
text_config = {}
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
if "itm" in model_name:
|
||||
config = Blip2Config(
|
||||
vision_config=vision_config,
|
||||
qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
|
||||
)
|
||||
else:
|
||||
config = Blip2Config(vision_config=vision_config, text_config=text_config)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(
|
||||
model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
if "opt" in model_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
|
||||
elif "itm" in model_name:
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
||||
|
||||
if "itm" in model_name:
|
||||
eos_token_id = None
|
||||
else:
|
||||
eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
|
||||
config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)
|
||||
|
||||
if "itm" in model_name:
|
||||
hf_model = Blip2ForImageTextRetrieval(config).eval()
|
||||
else:
|
||||
hf_model = Blip2ForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
|
||||
"blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
|
||||
"blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
|
||||
"blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
|
||||
"blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
|
||||
"blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
|
||||
"blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
|
||||
"blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
|
||||
"blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config, model_name)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "opt_proj" in key:
|
||||
key = key.replace("opt_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("opt"):
|
||||
key = key.replace("opt", "language")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert len(missing_keys) == 0
|
||||
|
||||
if "itm" in model_name:
|
||||
unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
|
||||
assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
|
||||
else:
|
||||
assert unexpected_keys == ["qformer.embeddings.position_ids"]
|
||||
|
||||
image = load_demo_image()
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
|
||||
if "itm" in model_name:
|
||||
caption = "a large fountain spewing water into the air"
|
||||
input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=True,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
|
||||
original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
|
||||
itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
|
||||
assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
with torch.no_grad():
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
|
||||
)
|
||||
logits = hf_model(
|
||||
pixel_values=pixel_values,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_image_text_matching_head=False,
|
||||
)
|
||||
|
||||
assert original_logits.shape == logits.logits_per_image.shape
|
||||
print("First values of original logits:", original_logits[0, :3])
|
||||
print("First values of HF logits:", logits.logits_per_image[0, :3])
|
||||
|
||||
# assert values
|
||||
# cast to same type
|
||||
target_dtype = logits.logits_per_image.dtype
|
||||
assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
else:
|
||||
input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
with torch.no_grad():
|
||||
if "opt" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
|
||||
logits = hf_model(pixel_values, input_ids).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
|
||||
).logits
|
||||
labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(pixel_values, input_ids, labels=labels).logits
|
||||
|
||||
assert original_logits.shape == logits.shape
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating a caption...")
|
||||
prompt = "Question: what object is in this image? Answer:"
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)
|
||||
|
||||
set_seed(42)
|
||||
|
||||
original_outputs = original_model.generate(
|
||||
{"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
|
||||
)
|
||||
outputs = hf_model.generate(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=5,
|
||||
max_length=30,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.0,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("Original generation:", original_outputs)
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(f"nielsr/{model_name}")
|
||||
hf_model.push_to_hub(f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
"blip2-itm-vit-g",
|
||||
"blip2-itm-vit-g-coco",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="blip2-opt-2.7b",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
# note: this script is tested on 2 GPUs, as models are compared in float32,
|
||||
# which requires quite some memory. Hence loading both on a
|
||||
# separate device is the easiest to compare
|
||||
parser.add_argument(
|
||||
"--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip2_checkpoint(
|
||||
args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
|
||||
)
|
@ -1239,6 +1239,9 @@ class Blip2TextEmbeddings(nn.Module):
|
||||
embeddings += position_embeddings
|
||||
|
||||
if query_embeds is not None:
|
||||
# `query_embeds` are kept in fp32 when we use it with Qformer
|
||||
if query_embeds.dtype != embeddings.dtype:
|
||||
query_embeds = query_embeds.to(embeddings.dtype)
|
||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||
else:
|
||||
embeddings = query_embeds
|
||||
@ -1386,6 +1389,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
# Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already
|
||||
if encoder_hidden_states.dtype != query_embeds.dtype:
|
||||
encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)
|
||||
|
||||
if isinstance(encoder_hidden_states, list):
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
@ -1448,6 +1455,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
class Blip2Model(Blip2PreTrainedModel):
|
||||
config_class = Blip2Config
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@ -1728,6 +1736,10 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_model_attention_mask = torch.ones(
|
||||
@ -1799,7 +1811,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
)
|
||||
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
||||
supports_gradient_checkpointing = False
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@ -1898,7 +1910,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
||||
)
|
||||
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@ -2019,6 +2031,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@ -2191,6 +2204,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_model_attention_mask = torch.ones(
|
||||
@ -2312,6 +2329,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
query_output = query_outputs.last_hidden_state
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
@ -2371,7 +2392,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
@ -1,254 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigScience BLOOM checkpoint."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BloomConfig, BloomModel
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
WEIGHTS_TO_AVERAGE_ENDSWITH = [
|
||||
"word_embeddings_layernorm.weight",
|
||||
"word_embeddings_layernorm.bias",
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"post_attention_layernorm.weight",
|
||||
"post_attention_layernorm.bias",
|
||||
"self_attention.dense.bias",
|
||||
"mlp.dense_4h_to_h.bias",
|
||||
"ln_f.weight",
|
||||
"ln_f.bias",
|
||||
]
|
||||
|
||||
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
|
||||
"mlp.dense_4h_to_h.weight",
|
||||
"self_attention.dense.weight",
|
||||
]
|
||||
|
||||
|
||||
def layer_name_mapping(key, file):
|
||||
"""Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
|
||||
# Handle first and last layers
|
||||
layer_rename_map = {
|
||||
"word_embeddings.weight": "word_embeddings.weight",
|
||||
"word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
|
||||
"word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
|
||||
"weight": "ln_f.weight",
|
||||
"bias": "ln_f.bias",
|
||||
}
|
||||
|
||||
if key in layer_rename_map:
|
||||
return layer_rename_map[key]
|
||||
|
||||
# Handle transformer blocks
|
||||
layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
|
||||
layer_number -= 3
|
||||
return f"h.{layer_number}." + key
|
||||
|
||||
|
||||
def get_dtype_size(dtype):
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def convert_bloom_checkpoint_to_pytorch(
|
||||
bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
|
||||
):
|
||||
# Construct model
|
||||
if bloom_config_file == "":
|
||||
config = BloomConfig()
|
||||
else:
|
||||
config = BloomConfig.from_json_file(bloom_config_file)
|
||||
|
||||
if shard_model:
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
index_dict = {"weight_map": {}, "metadata": {}}
|
||||
total_size = 0
|
||||
|
||||
missing_keys = None
|
||||
|
||||
config = BloomConfig()
|
||||
|
||||
for j, file in enumerate(file_names):
|
||||
print("Processing file: {}".format(file))
|
||||
tensors = None
|
||||
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
torch.save(
|
||||
tensors,
|
||||
os.path.join(
|
||||
pytorch_dump_folder_path,
|
||||
"pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
|
||||
),
|
||||
)
|
||||
|
||||
for key in tensors.keys():
|
||||
value = tensors[key]
|
||||
total_size += value.numel() * get_dtype_size(value.dtype)
|
||||
if key not in index_dict["weight_map"]:
|
||||
index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
|
||||
str(j + 1).zfill(5), str(len(file_names)).zfill(5)
|
||||
)
|
||||
|
||||
config = BloomConfig()
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
index_dict["metadata"]["total_size"] = total_size
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
|
||||
json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
|
||||
f.write(json_config)
|
||||
else:
|
||||
model = BloomModel(config)
|
||||
|
||||
file_names = os.listdir(bloom_checkpoint_path)
|
||||
file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))
|
||||
|
||||
missing_keys = None
|
||||
for i, file in enumerate(file_names):
|
||||
tensors = None
|
||||
for i in range(pretraining_tp):
|
||||
# load all TP files
|
||||
f_name = file.replace("model_00", f"model_0{i}")
|
||||
temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")
|
||||
|
||||
# Rename keys in the transformers names
|
||||
keys = list(temp.keys())
|
||||
for key in keys:
|
||||
temp[layer_name_mapping(key, file)] = temp.pop(key)
|
||||
|
||||
if tensors is None:
|
||||
tensors = temp
|
||||
else:
|
||||
for key in tensors.keys():
|
||||
# We average (sum and then divide) some weights accross TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] += temp[key]
|
||||
else:
|
||||
# Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
|
||||
cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
|
||||
# We concatenate these weights accross TP ranks
|
||||
tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)
|
||||
|
||||
# Divide by the number of TP the weights we want to average
|
||||
for key in tensors.keys():
|
||||
if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
|
||||
tensors[key] = tensors[key] / pretraining_tp
|
||||
|
||||
other_keys = model.load_state_dict(tensors, strict=False)
|
||||
assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
|
||||
if missing_keys is None:
|
||||
missing_keys = set(other_keys.missing_keys)
|
||||
else:
|
||||
missing_keys = missing_keys.intersection(set(other_keys.missing_keys))
|
||||
|
||||
assert not missing_keys, f"The keys {missing_keys} are missing"
|
||||
|
||||
# Save pytorch-model
|
||||
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
|
||||
print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
|
||||
if config.torch_dtype is not None:
|
||||
model = model.to(config.torch_dtype)
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print(f"Save configuration file to {pytorch_config_dump_path}")
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--bloom_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the Megatron-LM checkpoint path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bloom_config_file",
|
||||
default="",
|
||||
type=str,
|
||||
help=(
|
||||
"An optional config json file corresponding to the pre-trained model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_model",
|
||||
action="store_true",
|
||||
help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretraining_tp",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_bloom_checkpoint_to_pytorch(
|
||||
args.bloom_checkpoint_path,
|
||||
args.bloom_config_file,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.shard_model,
|
||||
args.pretraining_tp,
|
||||
)
|
@ -1,145 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Bros checkpoints."""
|
||||
|
||||
import argparse
|
||||
|
||||
import bros # original repo
|
||||
import torch
|
||||
|
||||
from transformers import BrosConfig, BrosModel, BrosProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_configs(model_name):
|
||||
bros_config = BrosConfig.from_pretrained(model_name)
|
||||
return bros_config
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"embeddings.bbox_sinusoid_emb.inv_freq",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if name == "embeddings.bbox_projection.weight":
|
||||
name = "bbox_embeddings.bbox_projection.weight"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"
|
||||
|
||||
if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
|
||||
name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, model):
|
||||
# rename keys
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
orig_state_dict[rename_key(key)] = val
|
||||
|
||||
# remove ignore keys
|
||||
remove_ignore_keys_(orig_state_dict)
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||
# load original model
|
||||
original_model = bros.BrosModel.from_pretrained(model_name).eval()
|
||||
|
||||
# load HuggingFace Model
|
||||
bros_config = get_configs(model_name)
|
||||
model = BrosModel.from_pretrained(model_name, config=bros_config)
|
||||
model.eval()
|
||||
|
||||
state_dict = original_model.state_dict()
|
||||
new_state_dict = convert_state_dict(state_dict, model)
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# verify results
|
||||
|
||||
# original BROS model require 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
|
||||
bbox = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
|
||||
[0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
|
||||
[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
processor = BrosProcessor.from_pretrained(model_name)
|
||||
|
||||
encoding = processor("His name is Rocco.", return_tensors="pt")
|
||||
encoding["bbox"] = bbox
|
||||
|
||||
original_hidden_states = original_model(**encoding).last_hidden_state
|
||||
# pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
last_hidden_states = model(**encoding).last_hidden_state
|
||||
|
||||
assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="jinho8345/bros-base-uncased",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Name of the original model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
required=False,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the converted model and processor to the 🤗 hub.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,59 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert T5 checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
config = T5Config.from_json_file(config_file)
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
|
@ -1,65 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert CANINE checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
|
||||
# Initialize PyTorch model
|
||||
config = CanineConfig()
|
||||
model = CanineModel(config)
|
||||
model.eval()
|
||||
|
||||
print(f"Building PyTorch model from configuration: {config}")
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_canine(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model (weights and configuration)
|
||||
print(f"Save PyTorch model to {pytorch_dump_path}")
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
# Save tokenizer files
|
||||
tokenizer = CanineTokenizer()
|
||||
print(f"Save tokenizer files to {pytorch_dump_path}")
|
||||
tokenizer.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a folder where the PyTorch model will be placed.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
|
@ -1,476 +0,0 @@
|
||||
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate import init_empty_weights
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ChameleonConfig,
|
||||
ChameleonForConditionalGeneration,
|
||||
ChameleonImageProcessor,
|
||||
ChameleonProcessor,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from transformers import LlamaTokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
|
||||
"Update your `tokenizers` library and re-run the tokenizer conversion."
|
||||
)
|
||||
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
```
|
||||
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
|
||||
Thereafter, models can be loaded via:
|
||||
|
||||
```py
|
||||
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
|
||||
|
||||
model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
|
||||
```
|
||||
|
||||
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
||||
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
||||
"""
|
||||
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"30B": 4,
|
||||
}
|
||||
|
||||
VOCAB_SIZE = 65536
|
||||
|
||||
|
||||
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
||||
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size, chameleon_version=1):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
input_model_path = os.path.join(input_base_path, "models", model_size.lower())
|
||||
params_path = os.path.join(input_model_path, "params.json")
|
||||
consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")
|
||||
|
||||
params = read_json(params_path)
|
||||
if os.path.isfile(consolidate_params_path):
|
||||
params = {**params, **read_json(consolidate_params_path)}
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
model_parallel_size = params["model_parallel_size"]
|
||||
params = params.get("model", params)
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
n_heads_per_shard = n_heads // num_shards
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = params.get("rope_theta", 10000.0)
|
||||
swin_norm = params["swin_norm"]
|
||||
if base > 10000.0:
|
||||
max_position_embeddings = 16384
|
||||
else:
|
||||
# Depending on the Chameleon version, the default max_position_embeddings has different values.
|
||||
if chameleon_version == 1:
|
||||
max_position_embeddings = 4096
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Version {chameleon_version} of chameleon is not supported yet. "
|
||||
"Current supported versions of chameleon are [1]."
|
||||
)
|
||||
|
||||
if params.get("n_kv_heads", None) is not None:
|
||||
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
||||
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
|
||||
key_value_dim = dim // num_key_value_heads
|
||||
else: # compatibility with other checkpoints
|
||||
num_key_value_heads = n_heads
|
||||
num_local_key_value_heads = n_heads_per_shard
|
||||
key_value_dim = dim
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
|
||||
# Load weights
|
||||
if num_shards == 1:
|
||||
# Not sharded
|
||||
# (The sharded implementation would also work, but this is simpler.)
|
||||
loaded = None
|
||||
for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
|
||||
possible_path = os.path.join(input_model_path, possible_name)
|
||||
if os.path.exists(possible_path):
|
||||
loaded = torch.load(possible_path, map_location="cpu")
|
||||
break
|
||||
assert loaded is not None
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
|
||||
for i in range(num_shards)
|
||||
]
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w, n_heads, dim1=dim, dim2=dim):
|
||||
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
# Load weights to the state dict
|
||||
state_dict = {}
|
||||
for layer_i in range(n_layers):
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wk.weight"],
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(n_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
|
||||
.view(dims_per_head // 2, 2)
|
||||
.t()
|
||||
.reshape(1, -1)
|
||||
.repeat_interleave(num_key_value_heads, 0)
|
||||
)
|
||||
|
||||
else:
|
||||
# Sharded
|
||||
state_dict.update(
|
||||
{
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
|
||||
[l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
|
||||
).mean(dim=0),
|
||||
}
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim),
|
||||
n_heads=n_heads,
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim),
|
||||
n_heads=num_key_value_heads,
|
||||
dim1=key_value_dim,
|
||||
)
|
||||
|
||||
# qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(n_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
|
||||
torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
|
||||
.view(num_shards, dims_per_head // 2, 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(num_shards, -1)
|
||||
.repeat_interleave(num_key_value_heads // num_shards, 0)
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
|
||||
num_local_key_value_heads, dims_per_head, dim
|
||||
)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(key_value_dim, dim)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
|
||||
if num_shards == 1:
|
||||
# Unsharded
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
||||
"model.norm.weight": loaded["norm.weight"],
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
state_dict.update(
|
||||
{
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
||||
),
|
||||
"model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
|
||||
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
||||
}
|
||||
)
|
||||
|
||||
# Load VQGAN weights
|
||||
vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
|
||||
vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"]
|
||||
for k, v in vqgan_state_dict.items():
|
||||
if "decoder" in k:
|
||||
continue # we dont do image generation yet
|
||||
state_dict[f"model.vqmodel.{k}"] = v
|
||||
|
||||
# Write configs
|
||||
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
|
||||
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
|
||||
tokenizer_config = json.load(tokenizer_file)
|
||||
vocabulary_map = tokenizer_config["model"]["vocab"]
|
||||
vocabulary_map["<image>"] = vocabulary_map[
|
||||
"<reserved08707>"
|
||||
] # use a reserved token instead of adding a new one
|
||||
del vocabulary_map["<reserved08707>"]
|
||||
|
||||
for token in tokenizer_config["added_tokens"]:
|
||||
if token["content"] == "<reserved08707>":
|
||||
token["content"] = "<image>"
|
||||
|
||||
with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
|
||||
json.dump(tokenizer_config, f) # save the new file to init tokenizer later
|
||||
|
||||
vq_keys_to_replace = [
|
||||
("ch", "base_channels"),
|
||||
("out_ch", "out_channels"),
|
||||
("n_embed", "num_embeddings"),
|
||||
("ch_mult", "channel_multiplier"),
|
||||
("double_z", "double_latent"),
|
||||
("z_channels", "latent_channels"),
|
||||
]
|
||||
with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
|
||||
vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
|
||||
vq_config.update(**vq_config["ddconfig"])
|
||||
for old, new in vq_keys_to_replace:
|
||||
vq_config[new] = vq_config[old]
|
||||
del vq_config["ddconfig"]
|
||||
del vq_config["ckpt_path"]
|
||||
del vq_config["lossconfig"]
|
||||
|
||||
config = ChameleonConfig(
|
||||
hidden_size=dim,
|
||||
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
rope_theta=base,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
model_parallel_size=model_parallel_size,
|
||||
swin_norm=swin_norm,
|
||||
vq_config=vq_config,
|
||||
vocabulary_map=vocabulary_map,
|
||||
)
|
||||
with init_empty_weights():
|
||||
model = ChameleonForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(state_dict, assign=True, strict=False)
|
||||
model.save_pretrained(model_path, safe_serialization=True)
|
||||
|
||||
# Load and save the processor
|
||||
tokenizer = LlamaTokenizerFast(
|
||||
tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
|
||||
)
|
||||
tokenizer.sep_token_id = 8710 # assign <reserved08706> to sep so that we can append it after input text
|
||||
tokenizer.pad_token_id = 1 # assing <pad> to special pad_token
|
||||
image_processor = ChameleonImageProcessor()
|
||||
processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
processor.save_pretrained(model_path)
|
||||
|
||||
# Make space so we can load the model properly now.
|
||||
del state_dict
|
||||
del loaded
|
||||
del vqgan_state_dict
|
||||
gc.collect()
|
||||
|
||||
# Short inference on a few examples to check if generation makes sense
|
||||
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
|
||||
print("Loading the checkpoint in a Chameleon model...")
|
||||
print("*" * 100)
|
||||
model = ChameleonForConditionalGeneration.from_pretrained(
|
||||
model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
processor = ChameleonProcessor.from_pretrained(model_path)
|
||||
|
||||
prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
|
||||
image = Image.open(
|
||||
requests.get(
|
||||
"https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
|
||||
).raw
|
||||
)
|
||||
inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
|
||||
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for single-image: {generated_text}")
|
||||
print("*" * 100)
|
||||
|
||||
# Multi-image example
|
||||
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
|
||||
image = Image.open(
|
||||
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
|
||||
)
|
||||
image_2 = Image.open(
|
||||
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
|
||||
)
|
||||
|
||||
inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
|
||||
length = inputs.input_ids.shape[1]
|
||||
out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
|
||||
generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]
|
||||
|
||||
print(f"Generation for multi-image: {generated_text}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of Chameleon weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
choices=["7B", "30B"],
|
||||
help=""
|
||||
" models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, checkout the original repo: https://github.com/facebookresearch/chameleon",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_inference",
|
||||
action="store_true",
|
||||
help="Whether to load the model for generation to test it's converted correctly.",
|
||||
)
|
||||
# Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
|
||||
parser.add_argument(
|
||||
"--chameleon_version",
|
||||
choices=[1],
|
||||
default=1,
|
||||
type=int,
|
||||
help="Version of the Chameleon model to convert",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=args.input_dir,
|
||||
model_size=args.model_size,
|
||||
chameleon_version=args.chameleon_version,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1289,13 +1289,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
|
||||
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
|
||||
raise ValueError(
|
||||
@ -1304,6 +1301,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# torch.jit.trace() doesn't support cache objects in the output
|
||||
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
||||
past_key_values = DynamicCache()
|
||||
|
@ -1,134 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ChineseCLIPConfig, ChineseCLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
|
||||
q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
|
||||
out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight.data = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias.data = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_weights, prefix):
|
||||
copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
|
||||
copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_weights, prefix):
|
||||
hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
|
||||
hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_weights, prefix):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
|
||||
copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_weights, prefix):
|
||||
for layer_id, hf_layer in enumerate(hf_layers):
|
||||
copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
|
||||
|
||||
# copy text encoder
|
||||
for name, param in hf_model.text_model.named_parameters():
|
||||
param.data = pt_weights[f"bert.{name}"].data
|
||||
|
||||
|
||||
def copy_vision_model_and_projection(hf_model, pt_weights):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
|
||||
|
||||
# copy embeddings
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
|
||||
hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
|
||||
assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
|
||||
config = ChineseCLIPConfig.from_pretrained(config_path)
|
||||
|
||||
hf_model = ChineseCLIPModel(config).eval()
|
||||
|
||||
pt_weights = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
|
||||
pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_weights)
|
||||
copy_vision_model_and_projection(hf_model, pt_weights)
|
||||
hf_model.logit_scale.data = pt_weights["logit_scale"].data
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output folder storing converted hf PyTorch model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
||||
print("The conversion is finished!")
|
@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
|
||||
from laion_clap import CLAP_Module
|
||||
|
||||
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
|
||||
|
||||
|
||||
KEYS_TO_MODIFY_MAPPING = {
|
||||
"text_branch": "text_model",
|
||||
"audio_branch": "audio_model.audio_encoder",
|
||||
"attn": "attention.self",
|
||||
"self.proj": "output.dense",
|
||||
"attention.self_mask": "attn_mask",
|
||||
"mlp.fc1": "intermediate.dense",
|
||||
"mlp.fc2": "output.dense",
|
||||
"norm1": "layernorm_before",
|
||||
"norm2": "layernorm_after",
|
||||
"bn0": "batch_norm",
|
||||
}
|
||||
|
||||
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
|
||||
|
||||
|
||||
def init_clap(checkpoint_path, model_type, enable_fusion=False):
|
||||
model = CLAP_Module(
|
||||
amodel=model_type,
|
||||
enable_fusion=enable_fusion,
|
||||
)
|
||||
model.load_ckpt(checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
def get_config_from_original(clap_model):
|
||||
audio_config = {
|
||||
"patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
|
||||
"depths": clap_model.model.audio_branch.depths,
|
||||
"hidden_size": clap_model.model.audio_projection[0].in_features,
|
||||
}
|
||||
|
||||
text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
|
||||
|
||||
return ClapConfig(audio_config=audio_config, text_config=text_config)
|
||||
|
||||
|
||||
def rename_state_dict(state_dict):
|
||||
model_state_dict = {}
|
||||
|
||||
sequential_layers_pattern = r".*sequential.(\d+).*"
|
||||
text_projection_pattern = r".*_projection.(\d+).*"
|
||||
|
||||
for key, value in state_dict.items():
|
||||
# check if any key needs to be modified
|
||||
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in key:
|
||||
key = key.replace(key_to_modify, new_key)
|
||||
|
||||
if re.match(sequential_layers_pattern, key):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
# Because in CLAP they use `nn.Sequential`...
|
||||
transformers_projection_layer = 1 if projecton_layer == 0 else 2
|
||||
|
||||
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
|
||||
|
||||
if "audio" and "qkv" in key:
|
||||
# split qkv into query key and value
|
||||
mixed_qkv = value
|
||||
qkv_dim = mixed_qkv.size(0) // 3
|
||||
|
||||
query_layer = mixed_qkv[:qkv_dim]
|
||||
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
||||
value_layer = mixed_qkv[qkv_dim * 2 :]
|
||||
|
||||
model_state_dict[key.replace("qkv", "query")] = query_layer
|
||||
model_state_dict[key.replace("qkv", "key")] = key_layer
|
||||
model_state_dict[key.replace("qkv", "value")] = value_layer
|
||||
else:
|
||||
model_state_dict[key] = value
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
|
||||
clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
|
||||
|
||||
clap_model.eval()
|
||||
state_dict = clap_model.model.state_dict()
|
||||
state_dict = rename_state_dict(state_dict)
|
||||
|
||||
transformers_config = get_config_from_original(clap_model)
|
||||
transformers_config.audio_config.enable_fusion = enable_fusion
|
||||
model = ClapModel(transformers_config)
|
||||
|
||||
# ignore the spectrogram embedding layer
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
transformers_config.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
|
||||
parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clap_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
|
||||
)
|
@ -1,156 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from clip import load
|
||||
|
||||
from transformers import CLIPConfig, CLIPModel
|
||||
|
||||
|
||||
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
|
||||
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
|
||||
|
||||
out_proj_weights = pt_attn_layer.out_proj.weight
|
||||
out_proj_bias = pt_attn_layer.out_proj.bias
|
||||
|
||||
hf_attn_layer.q_proj.weight.data = q_proj
|
||||
hf_attn_layer.q_proj.bias.data = q_proj_bias
|
||||
|
||||
hf_attn_layer.k_proj.weight.data = k_proj
|
||||
hf_attn_layer.k_proj.bias.data = k_proj_bias
|
||||
|
||||
hf_attn_layer.v_proj.weight.data = v_proj
|
||||
hf_attn_layer.v_proj.bias.data = v_proj_bias
|
||||
|
||||
hf_attn_layer.out_proj.weight = out_proj_weights
|
||||
hf_attn_layer.out_proj.bias = out_proj_bias
|
||||
|
||||
|
||||
def copy_mlp(hf_mlp, pt_mlp):
|
||||
copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
|
||||
copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
|
||||
|
||||
|
||||
def copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
|
||||
def copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
|
||||
copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
|
||||
|
||||
# copy MLP
|
||||
copy_mlp(hf_layer.mlp, pt_layer.mlp)
|
||||
|
||||
# copy attn
|
||||
copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
|
||||
|
||||
|
||||
def copy_layers(hf_layers, pt_layers):
|
||||
for hf_layer, pt_layer in zip(hf_layers, pt_layers):
|
||||
copy_layer(hf_layer, pt_layer)
|
||||
|
||||
|
||||
def copy_encoder(hf_encoder, pt_model):
|
||||
# copy embeds
|
||||
hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
|
||||
hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
|
||||
|
||||
# copy layer norm
|
||||
copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
|
||||
|
||||
# copy hidden layers
|
||||
copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
|
||||
|
||||
|
||||
def copy_text_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
|
||||
|
||||
# copy text encoder
|
||||
copy_encoder(hf_model.text_model, pt_model)
|
||||
|
||||
|
||||
def copy_vison_model_and_projection(hf_model, pt_model):
|
||||
# copy projection
|
||||
hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
|
||||
|
||||
# copy layer norms
|
||||
copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
|
||||
copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
|
||||
|
||||
# copy embeds
|
||||
hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
|
||||
hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
|
||||
hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
|
||||
|
||||
# copy encoder
|
||||
copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = CLIPConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
|
||||
|
||||
hf_model = CLIPModel(config).eval()
|
||||
|
||||
pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
|
||||
pt_model = pt_model.eval()
|
||||
|
||||
copy_text_model_and_projection(hf_model, pt_model)
|
||||
copy_vison_model_and_projection(hf_model, pt_model)
|
||||
hf_model.logit_scale = pt_model.logit_scale
|
||||
|
||||
# Use `eos_token` so the example is more meaningful
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[config.text_config.bos_token_id]
|
||||
+ list(range(3, 77))
|
||||
+ [config.text_config.eos_token_id]
|
||||
+ [config.text_config.pad_token_id]
|
||||
]
|
||||
)
|
||||
pixel_values = torch.randn(1, 3, 224, 224)
|
||||
|
||||
hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
|
||||
hf_logits_per_image = hf_outputs.logits_per_image
|
||||
hf_logits_per_text = hf_outputs.logits_per_text
|
||||
pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
|
||||
|
||||
assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
|
||||
assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
|
||||
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
|
@ -1,264 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
CLIPSegConfig,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPSegProcessor,
|
||||
CLIPSegTextConfig,
|
||||
CLIPSegVisionConfig,
|
||||
CLIPTokenizer,
|
||||
ViTImageProcessor,
|
||||
)
|
||||
|
||||
|
||||
def get_clipseg_config(model_name):
|
||||
text_config = CLIPSegTextConfig()
|
||||
vision_config = CLIPSegVisionConfig(patch_size=16)
|
||||
|
||||
use_complex_transposed_convolution = True if "refined" in model_name else False
|
||||
reduce_dim = 16 if "rd16" in model_name else 64
|
||||
|
||||
config = CLIPSegConfig.from_text_vision_configs(
|
||||
text_config,
|
||||
vision_config,
|
||||
use_complex_transposed_convolution=use_complex_transposed_convolution,
|
||||
reduce_dim=reduce_dim,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
# update prefixes
|
||||
if "clip_model" in name:
|
||||
name = name.replace("clip_model", "clip")
|
||||
if "transformer" in name:
|
||||
if "visual" in name:
|
||||
name = name.replace("visual.transformer", "vision_model")
|
||||
else:
|
||||
name = name.replace("transformer", "text_model")
|
||||
if "resblocks" in name:
|
||||
name = name.replace("resblocks", "encoder.layers")
|
||||
if "ln_1" in name:
|
||||
name = name.replace("ln_1", "layer_norm1")
|
||||
if "ln_2" in name:
|
||||
name = name.replace("ln_2", "layer_norm2")
|
||||
if "c_fc" in name:
|
||||
name = name.replace("c_fc", "fc1")
|
||||
if "c_proj" in name:
|
||||
name = name.replace("c_proj", "fc2")
|
||||
if "attn" in name and "self" not in name:
|
||||
name = name.replace("attn", "self_attn")
|
||||
# text encoder
|
||||
if "token_embedding" in name:
|
||||
name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
|
||||
if "positional_embedding" in name and "visual" not in name:
|
||||
name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
|
||||
if "ln_final" in name:
|
||||
name = name.replace("ln_final", "text_model.final_layer_norm")
|
||||
# vision encoder
|
||||
if "visual.class_embedding" in name:
|
||||
name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
|
||||
if "visual.conv1" in name:
|
||||
name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
|
||||
if "visual.positional_embedding" in name:
|
||||
name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
|
||||
if "visual.ln_pre" in name:
|
||||
name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
|
||||
if "visual.ln_post" in name:
|
||||
name = name.replace("visual.ln_post", "vision_model.post_layernorm")
|
||||
# projection layers
|
||||
if "visual.proj" in name:
|
||||
name = name.replace("visual.proj", "visual_projection.weight")
|
||||
if "text_projection" in name:
|
||||
name = name.replace("text_projection", "text_projection.weight")
|
||||
# decoder
|
||||
if "trans_conv" in name:
|
||||
name = name.replace("trans_conv", "transposed_convolution")
|
||||
if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
|
||||
name = "decoder." + name
|
||||
if "blocks" in name:
|
||||
name = name.replace("blocks", "decoder.layers")
|
||||
if "linear1" in name:
|
||||
name = name.replace("linear1", "mlp.fc1")
|
||||
if "linear2" in name:
|
||||
name = name.replace("linear2", "mlp.fc2")
|
||||
if "norm1" in name and "layer_" not in name:
|
||||
name = name.replace("norm1", "layer_norm1")
|
||||
if "norm2" in name and "layer_" not in name:
|
||||
name = name.replace("norm2", "layer_norm2")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def convert_state_dict(orig_state_dict, config):
|
||||
for key in orig_state_dict.copy().keys():
|
||||
val = orig_state_dict.pop(key)
|
||||
|
||||
if key.startswith("clip_model") and "attn.in_proj" in key:
|
||||
key_split = key.split(".")
|
||||
if "visual" in key:
|
||||
layer_num = int(key_split[4])
|
||||
dim = config.vision_config.hidden_size
|
||||
prefix = "vision_model"
|
||||
else:
|
||||
layer_num = int(key_split[3])
|
||||
dim = config.text_config.hidden_size
|
||||
prefix = "text_model"
|
||||
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
|
||||
dim : dim * 2, :
|
||||
]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
elif "self_attn" in key and "out_proj" not in key:
|
||||
key_split = key.split(".")
|
||||
layer_num = int(key_split[1])
|
||||
dim = config.reduce_dim
|
||||
if "weight" in key:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
|
||||
else:
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
|
||||
orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
|
||||
else:
|
||||
new_name = rename_key(key)
|
||||
if "visual_projection" in new_name or "text_projection" in new_name:
|
||||
val = val.T
|
||||
orig_state_dict[new_name] = val
|
||||
|
||||
return orig_state_dict
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
return image
|
||||
|
||||
|
||||
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
|
||||
config = get_clipseg_config(model_name)
|
||||
model = CLIPSegForImageSegmentation(config)
|
||||
model.eval()
|
||||
|
||||
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# remove some keys
|
||||
for key in state_dict.copy().keys():
|
||||
if key.startswith("model"):
|
||||
state_dict.pop(key, None)
|
||||
|
||||
# rename some keys
|
||||
state_dict = convert_state_dict(state_dict, config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
|
||||
raise ValueError("Missing keys that are not expected: {}".format(missing_keys))
|
||||
if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
|
||||
raise ValueError(f"Unexpected keys: {unexpected_keys}")
|
||||
|
||||
image_processor = ViTImageProcessor(size=352)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
|
||||
processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
|
||||
|
||||
image = prepare_img()
|
||||
text = ["a glass", "something to fill", "wood", "a jar"]
|
||||
|
||||
inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify values
|
||||
expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
|
||||
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
|
||||
if model_name == "clipseg-rd64-refined":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
|
||||
)
|
||||
elif model_name == "clipseg-rd64":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
|
||||
)
|
||||
elif model_name == "clipseg-rd16":
|
||||
expected_masks_slice = torch.tensor(
|
||||
[[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model name {model_name} not supported.")
|
||||
|
||||
assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
|
||||
assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
|
||||
assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and processor for {model_name} to the hub")
|
||||
model.push_to_hub(f"CIDAS/{model_name}")
|
||||
processor.push_to_hub(f"CIDAS/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="clipseg-rd64",
|
||||
type=str,
|
||||
choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
|
||||
help=(
|
||||
"Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
|
||||
" reduce dimension)"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
|
||||
" the decoder weights."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -1,234 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Weights conversion script for CLVP
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
|
||||
"decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
|
||||
}
|
||||
|
||||
dim = 1024
|
||||
sub_dim = dim // 16
|
||||
|
||||
CLVP_ENCODERS_MAPPING = {
|
||||
"text_transformer.transformer.attn_layers": "text_encoder_model",
|
||||
"speech_transformer.transformer.attn_layers": "speech_encoder_model",
|
||||
"text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
|
||||
"speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
|
||||
"to_text_latent": "text_encoder_model.projection",
|
||||
"to_speech_latent": "speech_encoder_model.projection",
|
||||
"text_emb": "text_encoder_model.token_embedding",
|
||||
"speech_emb": "speech_encoder_model.token_embedding",
|
||||
"1.wrap.net.0": "mlp.fc1",
|
||||
"1.wrap.net.3": "mlp.fc2",
|
||||
"1.wrap": "self_attn",
|
||||
"to_out": "out_proj",
|
||||
"to_q": "q_proj",
|
||||
"to_k": "k_proj",
|
||||
"to_v": "v_proj",
|
||||
"temperature": "logit_scale",
|
||||
}
|
||||
|
||||
CLVP_DECODER_MAPPING = {
|
||||
"conditioning_encoder.init": "conditioning_encoder.mel_conv",
|
||||
"conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
|
||||
"mel_attn_blocks": "group_norms",
|
||||
".norm.weight": ".weight",
|
||||
".norm.bias": ".bias",
|
||||
"text_embedding": "conditioning_encoder.text_token_embedding",
|
||||
"text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
|
||||
"final_norm": "speech_decoder_model.final_norm",
|
||||
"mel_head": "speech_decoder_model.lm_head",
|
||||
"gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
|
||||
"mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
|
||||
"mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
|
||||
"gpt.h": "speech_decoder_model.model.decoder.layers",
|
||||
"ln_1": "input_layernorm",
|
||||
"ln_2": "post_attention_layernorm",
|
||||
}
|
||||
|
||||
|
||||
def update_index(present_index):
|
||||
if present_index % 2 == 0:
|
||||
return int(present_index / 2)
|
||||
else:
|
||||
return int((present_index - 1) / 2)
|
||||
|
||||
|
||||
def convert_encoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
# for input_rmsnorm.weight and post_attention_rmsnorm.weight
|
||||
if "0.0.g" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
if int(present_index) % 2 == 0:
|
||||
updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
|
||||
else:
|
||||
updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
|
||||
|
||||
if "transformer.attn_layers.layers" in updated_key:
|
||||
present_index = updated_key.split(".")[4]
|
||||
updated_index = update_index(int(present_index))
|
||||
updated_key = updated_key.replace(
|
||||
f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
|
||||
)
|
||||
|
||||
for k, v in CLVP_ENCODERS_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def convert_decoder_weights(original_weights):
|
||||
converted_weights = {}
|
||||
original_weights_keys = sorted(original_weights.keys())
|
||||
for original_key in original_weights_keys:
|
||||
updated_key = original_key
|
||||
if len(updated_key.split(".")) > 3:
|
||||
index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]
|
||||
|
||||
# for decoder attention
|
||||
if "attn.c_attn" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
|
||||
continue
|
||||
|
||||
if "attn.c_proj" in updated_key:
|
||||
converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
|
||||
original_weights[updated_key].squeeze(-1).T
|
||||
)
|
||||
continue
|
||||
|
||||
if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
|
||||
original_weights.pop(updated_key)
|
||||
continue
|
||||
|
||||
# conditional encoder attention
|
||||
if "qkv" in updated_key:
|
||||
if attr == "weight":
|
||||
slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
|
||||
else:
|
||||
slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
|
||||
|
||||
indices = torch.arange(dim)
|
||||
index1, index2, index3 = (
|
||||
indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
|
||||
)
|
||||
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index1], slice2[index3], slice3[index2]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index2], slice2[index1], slice3[index3]],
|
||||
axis=0,
|
||||
)
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
|
||||
[slice1[index3], slice2[index2], slice3[index1]],
|
||||
axis=0,
|
||||
)
|
||||
continue
|
||||
|
||||
if "proj_out" in updated_key:
|
||||
converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
|
||||
updated_key
|
||||
].squeeze(-1)
|
||||
continue
|
||||
|
||||
for k, v in CLVP_DECODER_MAPPING.items():
|
||||
if k in updated_key:
|
||||
updated_key = updated_key.replace(k, v)
|
||||
|
||||
converted_weights[updated_key] = original_weights.pop(original_key)
|
||||
|
||||
return converted_weights
|
||||
|
||||
|
||||
def _download(url: str, root: str):
|
||||
repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
|
||||
filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename=filename,
|
||||
force_filename=root,
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
|
||||
|
||||
def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
|
||||
converted_checkpoint = {}
|
||||
|
||||
for each_model_name, each_model_url in _MODELS.items():
|
||||
each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
|
||||
if not os.path.exists(each_model_path):
|
||||
print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
|
||||
_download(url=each_model_url, root=each_model_path)
|
||||
|
||||
if each_model_name == "clvp":
|
||||
clvp_checkpoint = torch.load(each_model_path, map_location="cpu")
|
||||
else:
|
||||
decoder_checkpoint = torch.load(each_model_path, map_location="cpu")
|
||||
|
||||
# Converting the weights
|
||||
converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
|
||||
converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))
|
||||
|
||||
config = ClvpConfig.from_pretrained("susnato/clvp_dev")
|
||||
model = ClvpModelForConditionalGeneration(config)
|
||||
|
||||
model.load_state_dict(converted_checkpoint, strict=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Model saved at {pytorch_dump_folder_path}!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model. (Please enter full path)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
|
@ -1,214 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert ColPali weights from the original repository to the HF model format.
|
||||
|
||||
Original repository: https://github.com/illuin-tech/colpali.
|
||||
|
||||
NOTE: This script was originally run using `torch==2.5.1` and with:
|
||||
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf-internal \
|
||||
--push_to_hub
|
||||
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.3-merged \
|
||||
--revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.3-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.colpali import ColPaliForRetrieval
|
||||
from transformers.models.colpali.configuration_colpali import ColPaliConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ORIGINAL_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
new_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
if key.startswith("custom_text_proj"):
|
||||
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
|
||||
if key.startswith("model."):
|
||||
new_key = key.replace("model.", "vlm.", 1)
|
||||
new_state_dict[new_key] = value
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
|
||||
directory_path = snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
allow_patterns=["*.safetensors"],
|
||||
)
|
||||
|
||||
original_state_dict = {}
|
||||
for path in glob.glob(f"{directory_path}/*"):
|
||||
if path.endswith(".safetensors"):
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
original_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
|
||||
if "lm_head.weight" not in original_state_dict:
|
||||
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
|
||||
"model.language_model.model.embed_tokens.weight"
|
||||
].clone()
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_colpali_weights_to_hf(
|
||||
model_id: str,
|
||||
output_dir: str,
|
||||
push_to_hub: bool,
|
||||
revision: Optional[str] = None,
|
||||
original_vlm_name_or_path: Optional[str] = None,
|
||||
):
|
||||
# Load the original model data
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
if original_vlm_name_or_path is not None:
|
||||
original_config._name_or_path = original_vlm_name_or_path
|
||||
if hasattr(original_config, "architectures"):
|
||||
delattr(original_config, "architectures")
|
||||
|
||||
original_state_dict = load_original_state_dict(model_id, revision=revision)
|
||||
|
||||
# Format the state_dict keys
|
||||
original_state_dict = rename_state_dict_keys(original_state_dict)
|
||||
|
||||
# Create the new config
|
||||
config = ColPaliConfig(
|
||||
vlm_config=original_config,
|
||||
embedding_dim=128, # hardcoded in the original model
|
||||
)
|
||||
config.model_type = "colpali"
|
||||
config.is_composition = False
|
||||
|
||||
# Load the untrained model
|
||||
model = ColPaliForRetrieval(config=config).to("cpu").eval()
|
||||
print("Created model with new config and randomly initialized weights")
|
||||
|
||||
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
|
||||
# There are two ways to set the model's dtype:
|
||||
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
|
||||
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
|
||||
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
|
||||
# the new weights' dtypes match the original model.
|
||||
for param in model.parameters():
|
||||
param.data = param.data.to(ORIGINAL_DTYPE)
|
||||
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
|
||||
|
||||
# Load the original weights
|
||||
model.load_state_dict(original_state_dict)
|
||||
print("Loaded original model weights")
|
||||
|
||||
# Tie the weights (following ColPali's `__init__`` step)
|
||||
if model.vlm.language_model._tied_weights_keys is not None:
|
||||
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
|
||||
|
||||
# Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
state_dict_keys_new = set(model.state_dict().keys())
|
||||
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
|
||||
if disjoint_keys:
|
||||
raise ValueError(f"Incompatible keys: {disjoint_keys}")
|
||||
|
||||
# Save the model
|
||||
if push_to_hub:
|
||||
model.push_to_hub(output_dir, private=True)
|
||||
print(f"Model pushed to the hub at `{output_dir}`")
|
||||
else:
|
||||
Path(output_dir).mkdir(exist_ok=True, parents=True)
|
||||
model.save_pretrained(output_dir)
|
||||
print(f"Model saved to `{output_dir}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script converts the original ColPali model to the HF model format.
|
||||
|
||||
Example usage:
|
||||
```bash
|
||||
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
|
||||
--model_id vidore/colpali-v1.2-merged \
|
||||
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
|
||||
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
|
||||
--output_dir vidore/colpali-v1.2-hf \
|
||||
--push_to_hub
|
||||
```
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_id",
|
||||
help="Model ID of the original model to convert",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
help="Revision of the model to download",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original_vlm_name_or_path",
|
||||
help="Name or path of the original VLM backbone model",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_colpali_weights_to_hf(
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
push_to_hub=args.push_to_hub,
|
||||
revision=args.revision,
|
||||
original_vlm_name_or_path=args.original_vlm_name_or_path,
|
||||
)
|
@ -1,324 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Conditional DETR checkpoints."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
ConditionalDetrConfig,
|
||||
ConditionalDetrForObjectDetection,
|
||||
ConditionalDetrForSegmentation,
|
||||
ConditionalDetrImageProcessor,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
rename_keys = []
|
||||
for i in range(6):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
|
||||
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.weight",
|
||||
)
|
||||
)
|
||||
rename_keys.append(
|
||||
(
|
||||
f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
|
||||
f"decoder.layers.{i}.encoder_attn.out_proj.bias",
|
||||
)
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))
|
||||
|
||||
# q, k, v projections in self/cross-attention in decoder for conditional DETR
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
|
||||
)
|
||||
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
|
||||
)
|
||||
# rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
|
||||
)
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
|
||||
rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
|
||||
rename_keys.append(
|
||||
(f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
|
||||
)
|
||||
|
||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||
# for conditional DETR, also convert reference point head and query scale MLP
|
||||
rename_keys.extend(
|
||||
[
|
||||
("input_proj.weight", "input_projection.weight"),
|
||||
("input_proj.bias", "input_projection.bias"),
|
||||
("query_embed.weight", "query_position_embeddings.weight"),
|
||||
("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
|
||||
("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
|
||||
("class_embed.weight", "class_labels_classifier.weight"),
|
||||
("class_embed.bias", "class_labels_classifier.bias"),
|
||||
("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
|
||||
("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
|
||||
("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
|
||||
("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
|
||||
("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
|
||||
("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
|
||||
("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
|
||||
("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
|
||||
("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
|
||||
("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
|
||||
("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
|
||||
("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
|
||||
("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def rename_key(state_dict, old, new):
|
||||
val = state_dict.pop(old)
|
||||
state_dict[new] = val
|
||||
|
||||
|
||||
def rename_backbone_keys(state_dict):
|
||||
new_state_dict = OrderedDict()
|
||||
for key, value in state_dict.items():
|
||||
if "backbone.0.body" in key:
|
||||
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
new_state_dict[key] = value
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def read_in_q_k_v(state_dict, is_panoptic=False):
|
||||
prefix = ""
|
||||
if is_panoptic:
|
||||
prefix = "conditional_detr."
|
||||
|
||||
# first: transformer encoder
|
||||
for i in range(6):
|
||||
# read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
|
||||
in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
|
||||
state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
|
||||
"""
|
||||
|
||||
# load default config
|
||||
config = ConditionalDetrConfig()
|
||||
# set backbone and dilation attributes
|
||||
if "resnet101" in model_name:
|
||||
config.backbone = "resnet101"
|
||||
if "dc5" in model_name:
|
||||
config.dilation = True
|
||||
is_panoptic = "panoptic" in model_name
|
||||
if is_panoptic:
|
||||
config.num_labels = 250
|
||||
else:
|
||||
config.num_labels = 91
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "coco-detection-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
# load image processor
|
||||
format = "coco_panoptic" if is_panoptic else "coco_detection"
|
||||
image_processor = ConditionalDetrImageProcessor(format=format)
|
||||
|
||||
# prepare image
|
||||
img = prepare_img()
|
||||
encoding = image_processor(images=img, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
logger.info(f"Converting model {model_name}...")
|
||||
|
||||
# load original model from torch hub
|
||||
conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
|
||||
state_dict = conditional_detr.state_dict()
|
||||
# rename keys
|
||||
for src, dest in rename_keys:
|
||||
if is_panoptic:
|
||||
src = "conditional_detr." + src
|
||||
rename_key(state_dict, src, dest)
|
||||
state_dict = rename_backbone_keys(state_dict)
|
||||
# query, key and value matrices need special treatment
|
||||
read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
|
||||
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||
prefix = "conditional_detr.model." if is_panoptic else "model."
|
||||
for key in state_dict.copy().keys():
|
||||
if is_panoptic:
|
||||
if (
|
||||
key.startswith("conditional_detr")
|
||||
and not key.startswith("class_labels_classifier")
|
||||
and not key.startswith("bbox_predictor")
|
||||
):
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr.model" + key[4:]] = val
|
||||
elif "class_labels_classifier" in key or "bbox_predictor" in key:
|
||||
val = state_dict.pop(key)
|
||||
state_dict["conditional_detr." + key] = val
|
||||
elif key.startswith("bbox_attention") or key.startswith("mask_head"):
|
||||
continue
|
||||
else:
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
else:
|
||||
if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
# finally, create HuggingFace model and load state dict
|
||||
model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
|
||||
# verify our conversion
|
||||
original_outputs = conditional_detr(pixel_values)
|
||||
outputs = model(pixel_values)
|
||||
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
|
||||
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
|
||||
if is_panoptic:
|
||||
assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
|
||||
|
||||
# Save model and image processor
|
||||
logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="conditional_detr_resnet50",
|
||||
type=str,
|
||||
help="Name of the CONDITIONAL_DETR model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
|
@ -1,57 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ConvBERT checkpoint."""
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
|
||||
conf = ConvBertConfig.from_json_file(convbert_config_file)
|
||||
model = ConvBertModel(conf)
|
||||
|
||||
model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
|
||||
tf_model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--convbert_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help=(
|
||||
"The config json file corresponding to the pre-trained ConvBERT model. \n"
|
||||
"This specifies the model architecture."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)
|
@ -1,242 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ConvNext checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/facebookresearch/ConvNeXt"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_convnext_config(checkpoint_url):
|
||||
config = ConvNextConfig()
|
||||
|
||||
if "tiny" in checkpoint_url:
|
||||
depths = [3, 3, 9, 3]
|
||||
hidden_sizes = [96, 192, 384, 768]
|
||||
if "small" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [96, 192, 384, 768]
|
||||
if "base" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [128, 256, 512, 1024]
|
||||
if "large" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [192, 384, 768, 1536]
|
||||
if "xlarge" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [256, 512, 1024, 2048]
|
||||
|
||||
if "1k" in checkpoint_url:
|
||||
num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
expected_shape = (1, 1000)
|
||||
else:
|
||||
num_labels = 21841
|
||||
filename = "imagenet-22k-id2label.json"
|
||||
expected_shape = (1, 21841)
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
config.num_labels = num_labels
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
if "1k" not in checkpoint_url:
|
||||
# this dataset contains 21843 labels but the model only has 21841
|
||||
# we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
|
||||
del id2label[9205]
|
||||
del id2label[15027]
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.hidden_sizes = hidden_sizes
|
||||
config.depths = depths
|
||||
|
||||
return config, expected_shape
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "downsample_layers.0.0" in name:
|
||||
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
|
||||
if "downsample_layers.0.1" in name:
|
||||
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
|
||||
if "downsample_layers.1.0" in name:
|
||||
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
|
||||
if "downsample_layers.1.1" in name:
|
||||
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
|
||||
if "downsample_layers.2.0" in name:
|
||||
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
|
||||
if "downsample_layers.2.1" in name:
|
||||
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
|
||||
if "downsample_layers.3.0" in name:
|
||||
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
|
||||
if "downsample_layers.3.1" in name:
|
||||
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
|
||||
if "stages" in name and "downsampling_layer" not in name:
|
||||
# stages.0.0. for instance should be renamed to stages.0.layers.0.
|
||||
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
|
||||
if "stages" in name:
|
||||
name = name.replace("stages", "encoder.stages")
|
||||
if "norm" in name:
|
||||
name = name.replace("norm", "layernorm")
|
||||
if "gamma" in name:
|
||||
name = name.replace("gamma", "layer_scale_parameter")
|
||||
if "head" in name:
|
||||
name = name.replace("head", "classifier")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ConvNext structure.
|
||||
"""
|
||||
|
||||
# define ConvNext configuration based on URL
|
||||
config, expected_shape = get_convnext_config(checkpoint_url)
|
||||
# load original state_dict from URL
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
|
||||
# rename keys
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val
|
||||
# add prefix to all keys expect classifier head
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
if not key.startswith("classifier"):
|
||||
key = "convnext." + key
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
model = ConvNextForImageClassification(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# Check outputs on an image, prepared by ConvNextImageProcessor
|
||||
size = 224 if "224" in checkpoint_url else 384
|
||||
image_processor = ConvNextImageProcessor(size=size)
|
||||
pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values
|
||||
|
||||
logits = model(pixel_values).logits
|
||||
|
||||
# note: the logits below were obtained without center cropping
|
||||
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([0.4525, 0.7539, 0.0308])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth":
|
||||
expected_logits = torch.tensor([0.3561, 0.6350, -0.0384])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([0.4174, -0.0989, 0.1489])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth":
|
||||
expected_logits = torch.tensor([0.2513, -0.1349, -0.1613])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.2980, 0.3631, -0.1198])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.2963, 0.1227, 0.1723])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth":
|
||||
expected_logits = torch.tensor([1.7956, 0.8390, 0.2820])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth":
|
||||
expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth":
|
||||
expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth":
|
||||
expected_logits = torch.tensor([0.2681, 0.2365, 0.6246])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth":
|
||||
expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth":
|
||||
expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth":
|
||||
expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444])
|
||||
else:
|
||||
raise ValueError(f"Unknown URL: {checkpoint_url}")
|
||||
|
||||
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3)
|
||||
assert logits.shape == expected_shape
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
print("Pushing model to the hub...")
|
||||
model_name = "convnext"
|
||||
if "tiny" in checkpoint_url:
|
||||
model_name += "-tiny"
|
||||
elif "small" in checkpoint_url:
|
||||
model_name += "-small"
|
||||
elif "base" in checkpoint_url:
|
||||
model_name += "-base"
|
||||
elif "xlarge" in checkpoint_url:
|
||||
model_name += "-xlarge"
|
||||
elif "large" in checkpoint_url:
|
||||
model_name += "-large"
|
||||
if "224" in checkpoint_url:
|
||||
model_name += "-224"
|
||||
elif "384" in checkpoint_url:
|
||||
model_name += "-384"
|
||||
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
|
||||
model_name += "-22k"
|
||||
if "22k" in checkpoint_url and "1k" in checkpoint_url:
|
||||
model_name += "-22k-1k"
|
||||
|
||||
model.push_to_hub(
|
||||
repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
|
||||
organization="nielsr",
|
||||
commit_message="Add model",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
type=str,
|
||||
help="URL of the original ConvNeXT checkpoint you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
@ -1,286 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert ConvNeXTV2 checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/facebookresearch/ConvNeXt"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_convnextv2_config(checkpoint_url):
|
||||
config = ConvNextV2Config()
|
||||
|
||||
if "atto" in checkpoint_url:
|
||||
depths = [2, 2, 6, 2]
|
||||
hidden_sizes = [40, 80, 160, 320]
|
||||
if "femto" in checkpoint_url:
|
||||
depths = [2, 2, 6, 2]
|
||||
hidden_sizes = [48, 96, 192, 384]
|
||||
if "pico" in checkpoint_url:
|
||||
depths = [2, 2, 6, 2]
|
||||
hidden_sizes = [64, 128, 256, 512]
|
||||
if "nano" in checkpoint_url:
|
||||
depths = [2, 2, 8, 2]
|
||||
hidden_sizes = [80, 160, 320, 640]
|
||||
if "tiny" in checkpoint_url:
|
||||
depths = [3, 3, 9, 3]
|
||||
hidden_sizes = [96, 192, 384, 768]
|
||||
if "base" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [128, 256, 512, 1024]
|
||||
if "large" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [192, 384, 768, 1536]
|
||||
if "huge" in checkpoint_url:
|
||||
depths = [3, 3, 27, 3]
|
||||
hidden_sizes = [352, 704, 1408, 2816]
|
||||
|
||||
num_labels = 1000
|
||||
filename = "imagenet-1k-id2label.json"
|
||||
expected_shape = (1, 1000)
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
config.num_labels = num_labels
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
config.hidden_sizes = hidden_sizes
|
||||
config.depths = depths
|
||||
|
||||
return config, expected_shape
|
||||
|
||||
|
||||
def rename_key(name):
|
||||
if "downsample_layers.0.0" in name:
|
||||
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
|
||||
if "downsample_layers.0.1" in name:
|
||||
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
|
||||
if "downsample_layers.1.0" in name:
|
||||
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
|
||||
if "downsample_layers.1.1" in name:
|
||||
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
|
||||
if "downsample_layers.2.0" in name:
|
||||
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
|
||||
if "downsample_layers.2.1" in name:
|
||||
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
|
||||
if "downsample_layers.3.0" in name:
|
||||
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
|
||||
if "downsample_layers.3.1" in name:
|
||||
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
|
||||
if "stages" in name and "downsampling_layer" not in name:
|
||||
# stages.0.0. for instance should be renamed to stages.0.layers.0.
|
||||
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
|
||||
if "gamma" in name:
|
||||
name = name.replace("gamma", "weight")
|
||||
if "beta" in name:
|
||||
name = name.replace("beta", "bias")
|
||||
if "stages" in name:
|
||||
name = name.replace("stages", "encoder.stages")
|
||||
if "norm" in name:
|
||||
name = name.replace("norm", "layernorm")
|
||||
if "head" in name:
|
||||
name = name.replace("head", "classifier")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
def convert_preprocessor(checkpoint_url):
|
||||
if "224" in checkpoint_url:
|
||||
size = 224
|
||||
crop_pct = 224 / 256
|
||||
elif "384" in checkpoint_url:
|
||||
size = 384
|
||||
crop_pct = None
|
||||
else:
|
||||
size = 512
|
||||
crop_pct = None
|
||||
|
||||
return ConvNextImageProcessor(
|
||||
size=size,
|
||||
crop_pct=crop_pct,
|
||||
image_mean=[0.485, 0.456, 0.406],
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
resample=PILImageResampling.BICUBIC,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our ConvNeXTV2 structure.
|
||||
"""
|
||||
print("Downloading original model from checkpoint...")
|
||||
# define ConvNeXTV2 configuration based on URL
|
||||
config, expected_shape = get_convnextv2_config(checkpoint_url)
|
||||
# load original state_dict from URL
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
|
||||
|
||||
print("Converting model parameters...")
|
||||
# rename keys
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
state_dict[rename_key(key)] = val
|
||||
# add prefix to all keys expect classifier head
|
||||
for key in state_dict.copy().keys():
|
||||
val = state_dict.pop(key)
|
||||
if not key.startswith("classifier"):
|
||||
key = "convnextv2." + key
|
||||
state_dict[key] = val
|
||||
|
||||
# load HuggingFace model
|
||||
model = ConvNextV2ForImageClassification(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# Check outputs on an image, prepared by ConvNextImageProcessor
|
||||
preprocessor = convert_preprocessor(checkpoint_url)
|
||||
inputs = preprocessor(images=prepare_img(), return_tensors="pt")
|
||||
logits = model(**inputs).logits
|
||||
|
||||
# note: the logits below were obtained without center cropping
|
||||
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt":
|
||||
expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt":
|
||||
expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt":
|
||||
expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt":
|
||||
expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt":
|
||||
expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt":
|
||||
expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt":
|
||||
expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt":
|
||||
expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309])
|
||||
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt":
|
||||
expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826])
|
||||
else:
|
||||
raise ValueError(f"Unknown URL: {checkpoint_url}")
|
||||
|
||||
assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3)
|
||||
assert logits.shape == expected_shape
|
||||
print("Model outputs match the original results!")
|
||||
|
||||
if save_model:
|
||||
print("Saving model to local...")
|
||||
# Create folder to save model
|
||||
if not os.path.isdir(pytorch_dump_folder_path):
|
||||
os.mkdir(pytorch_dump_folder_path)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
preprocessor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
model_name = "convnextv2"
|
||||
if "atto" in checkpoint_url:
|
||||
model_name += "-atto"
|
||||
if "femto" in checkpoint_url:
|
||||
model_name += "-femto"
|
||||
if "pico" in checkpoint_url:
|
||||
model_name += "-pico"
|
||||
if "nano" in checkpoint_url:
|
||||
model_name += "-nano"
|
||||
elif "tiny" in checkpoint_url:
|
||||
model_name += "-tiny"
|
||||
elif "base" in checkpoint_url:
|
||||
model_name += "-base"
|
||||
elif "large" in checkpoint_url:
|
||||
model_name += "-large"
|
||||
elif "huge" in checkpoint_url:
|
||||
model_name += "-huge"
|
||||
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
|
||||
model_name += "-22k"
|
||||
elif "22k" in checkpoint_url and "1k" in checkpoint_url:
|
||||
model_name += "-22k-1k"
|
||||
elif "1k" in checkpoint_url:
|
||||
model_name += "-1k"
|
||||
if "224" in checkpoint_url:
|
||||
model_name += "-224"
|
||||
elif "384" in checkpoint_url:
|
||||
model_name += "-384"
|
||||
elif "512" in checkpoint_url:
|
||||
model_name += "-512"
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing {model_name} to the hub...")
|
||||
model.push_to_hub(model_name)
|
||||
preprocessor.push_to_hub(model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",
|
||||
type=str,
|
||||
help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default="model",
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_convnextv2_checkpoint(
|
||||
args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
|
||||
)
|
@ -1,362 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert CvT checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/microsoft/CvT"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
|
||||
|
||||
|
||||
def embeddings(idx):
|
||||
"""
|
||||
The function helps in renaming embedding layer weights.
|
||||
|
||||
Args:
|
||||
idx: stage number in original model
|
||||
"""
|
||||
embed = []
|
||||
embed.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
|
||||
f"stage{idx}.patch_embed.proj.weight",
|
||||
)
|
||||
)
|
||||
embed.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
|
||||
f"stage{idx}.patch_embed.proj.bias",
|
||||
)
|
||||
)
|
||||
embed.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
|
||||
f"stage{idx}.patch_embed.norm.weight",
|
||||
)
|
||||
)
|
||||
embed.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
|
||||
f"stage{idx}.patch_embed.norm.bias",
|
||||
)
|
||||
)
|
||||
return embed
|
||||
|
||||
|
||||
def attention(idx, cnt):
|
||||
"""
|
||||
The function helps in renaming attention block layers weights.
|
||||
|
||||
Args:
|
||||
idx: stage number in original model
|
||||
cnt: count of blocks in each stage
|
||||
"""
|
||||
attention_weights = []
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
|
||||
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj.weight",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(
|
||||
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
|
||||
f"stage{idx}.blocks.{cnt}.attn.proj.bias",
|
||||
)
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
|
||||
)
|
||||
attention_weights.append(
|
||||
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
|
||||
)
|
||||
return attention_weights
|
||||
|
||||
|
||||
def cls_token(idx):
|
||||
"""
|
||||
Function helps in renaming cls_token weights
|
||||
"""
|
||||
token = []
|
||||
token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
|
||||
return token
|
||||
|
||||
|
||||
def final():
|
||||
"""
|
||||
Function helps in renaming final classification layer
|
||||
"""
|
||||
head = []
|
||||
head.append(("layernorm.weight", "norm.weight"))
|
||||
head.append(("layernorm.bias", "norm.bias"))
|
||||
head.append(("classifier.weight", "head.weight"))
|
||||
head.append(("classifier.bias", "head.bias"))
|
||||
return head
|
||||
|
||||
|
||||
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
|
||||
"""
|
||||
Fucntion to convert the microsoft cvt checkpoint to huggingface checkpoint
|
||||
"""
|
||||
img_labels_file = "imagenet-1k-id2label.json"
|
||||
num_labels = 1000
|
||||
|
||||
repo_id = "huggingface/label-files"
|
||||
num_labels = num_labels
|
||||
id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text())
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
|
||||
id2label = id2label
|
||||
label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
|
||||
|
||||
# For depth size 13 (13 = 1+2+10)
|
||||
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
|
||||
config.depth = [1, 2, 10]
|
||||
|
||||
# For depth size 21 (21 = 1+4+16)
|
||||
elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
|
||||
config.depth = [1, 4, 16]
|
||||
|
||||
# For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20)
|
||||
else:
|
||||
config.depth = [2, 2, 20]
|
||||
config.num_heads = [3, 12, 16]
|
||||
config.embed_dim = [192, 768, 1024]
|
||||
|
||||
model = CvtForImageClassification(config)
|
||||
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
|
||||
image_processor.size["shortest_edge"] = image_size
|
||||
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
|
||||
|
||||
huggingface_weights = OrderedDict()
|
||||
list_of_state_dict = []
|
||||
|
||||
for idx in range(len(config.depth)):
|
||||
if config.cls_token[idx]:
|
||||
list_of_state_dict = list_of_state_dict + cls_token(idx)
|
||||
list_of_state_dict = list_of_state_dict + embeddings(idx)
|
||||
for cnt in range(config.depth[idx]):
|
||||
list_of_state_dict = list_of_state_dict + attention(idx, cnt)
|
||||
|
||||
list_of_state_dict = list_of_state_dict + final()
|
||||
for gg in list_of_state_dict:
|
||||
print(gg)
|
||||
for i in range(len(list_of_state_dict)):
|
||||
huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
|
||||
|
||||
model.load_state_dict(huggingface_weights)
|
||||
model.save_pretrained(pytorch_dump_folder)
|
||||
image_processor.save_pretrained(pytorch_dump_folder)
|
||||
|
||||
|
||||
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--cvt_model",
|
||||
default="cvt-w24",
|
||||
type=str,
|
||||
help="Name of the cvt model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=384,
|
||||
type=int,
|
||||
help="Input Image Size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cvt_file_name",
|
||||
default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
|
||||
type=str,
|
||||
help="Input Image Size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)
|
@ -1,233 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert DAB-DETR checkpoints."""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
|
||||
# for dab-DETR, also convert reference point head and query scale MLP
|
||||
r"input_proj\.(bias|weight)": r"input_projection.\1",
|
||||
r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
|
||||
r"class_embed\.(bias|weight)": r"class_embed.\1",
|
||||
# negative lookbehind because of the overlap
|
||||
r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
|
||||
r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
|
||||
r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
|
||||
r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
|
||||
r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
|
||||
r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
|
||||
r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
|
||||
r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
|
||||
# output projection
|
||||
r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
|
||||
# FFN layers
|
||||
r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
|
||||
# normalization layers
|
||||
# nm1
|
||||
r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
|
||||
# nm2
|
||||
r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
|
||||
# activation function weight
|
||||
r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
|
||||
#########################################################################################################################################
|
||||
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activiation function weight
|
||||
r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
|
||||
# FFNs
|
||||
r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
|
||||
# nm1
|
||||
r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
|
||||
# nm2
|
||||
r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
|
||||
# nm3
|
||||
r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
|
||||
# activation function weight
|
||||
r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
|
||||
# q, k, v projections and biases in self-attention in decoder
|
||||
r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
|
||||
# q, k, v projections in cross-attention in decoder
|
||||
r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
|
||||
r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
|
||||
"""
|
||||
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||
the key mappings.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
|
||||
logger.info("Converting image processor...")
|
||||
format = "coco_detection"
|
||||
image_processor = ConditionalDetrImageProcessor(format=format)
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
||||
# load modified config. Why? After loading the default config, the backbone kwargs are already set.
|
||||
if "dc5" in model_name:
|
||||
config = DabDetrConfig(dilation=True)
|
||||
else:
|
||||
# load default config
|
||||
config = DabDetrConfig()
|
||||
# set other attributes
|
||||
if "dab-detr-resnet-50-dc5" == model_name:
|
||||
config.temperature_height = 10
|
||||
config.temperature_width = 10
|
||||
if "fixxy" in model_name:
|
||||
config.random_refpoints_xy = True
|
||||
if "pat3" in model_name:
|
||||
config.num_patterns = 3
|
||||
# only when the number of patterns (num_patterns parameter in config) are more than 0 like r50-pat3 or r50dc5-pat3
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
|
||||
|
||||
config.num_labels = 91
|
||||
repo_id = "huggingface/label-files"
|
||||
filename = "coco-detection-id2label.json"
|
||||
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
# load original model from local path
|
||||
loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"))["model"]
|
||||
# Renaming the original model state dictionary to HF compatibile
|
||||
all_keys = list(loaded.keys())
|
||||
new_keys = convert_old_keys_to_new_keys(all_keys)
|
||||
state_dict = {}
|
||||
for key in all_keys:
|
||||
if "backbone.0.body" in key:
|
||||
new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
|
||||
state_dict[new_key] = loaded[key]
|
||||
# Q, K, V encoder values mapping
|
||||
elif re.search("self_attn.in_proj_(weight|bias)", key):
|
||||
# Dynamically find the layer number
|
||||
pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
|
||||
match = re.search(pattern, key)
|
||||
if match:
|
||||
layer_num = match.group(1)
|
||||
else:
|
||||
raise ValueError(f"Pattern not found in key: {key}")
|
||||
|
||||
in_proj_value = loaded.pop(key)
|
||||
if "weight" in key:
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
|
||||
elif "bias" in key:
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
|
||||
state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
|
||||
else:
|
||||
new_key = new_keys[key]
|
||||
state_dict[new_key] = loaded[key]
|
||||
|
||||
del loaded
|
||||
gc.collect()
|
||||
# important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
|
||||
prefix = "model."
|
||||
for key in state_dict.copy().keys():
|
||||
if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
|
||||
val = state_dict.pop(key)
|
||||
state_dict[prefix + key] = val
|
||||
# finally, create HuggingFace model and load state dict
|
||||
model = DabDetrForObjectDetection(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
model.push_to_hub(repo_id=model_name, commit_message="Add new model")
|
||||
|
||||
|
||||
def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
|
||||
logger.info("Converting image processor...")
|
||||
write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
|
||||
|
||||
logger.info(f"Converting model {model_name}...")
|
||||
write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="dab-detr-resnet-50",
|
||||
type=str,
|
||||
help="Name of the DAB_DETR model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_weights_path",
|
||||
default="modelzoo/R50/checkpoint.pth",
|
||||
type=str,
|
||||
help="The path of the original model weights like: modelzoo/checkpoint.pth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_dab_detr_checkpoint(
|
||||
args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
|
||||
)
|
@ -1,261 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import fnmatch
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
DacConfig,
|
||||
DacFeatureExtractor,
|
||||
DacModel,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
# checkpoints downloaded using:
|
||||
# pip install descript-audio-codec
|
||||
# python3 -m dac download # downloads the default 44kHz variant
|
||||
# python3 -m dac download --model_type 44khz # downloads the 44kHz variant
|
||||
# python3 -m dac download --model_type 24khz # downloads the 24kHz variant
|
||||
# python3 -m dac download --model_type 16khz # downloads the 16kHz variant
|
||||
# More informations: https://github.com/descriptinc/descript-audio-codec/tree/main
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.models.dac")
|
||||
|
||||
|
||||
def match_pattern(string, pattern):
|
||||
# Split the pattern into parts
|
||||
pattern_parts = pattern.split(".")
|
||||
string_parts = string.split(".")
|
||||
|
||||
pattern_block_count = string_block_count = 0
|
||||
|
||||
for part in pattern_parts:
|
||||
if part.startswith("block"):
|
||||
pattern_block_count += 1
|
||||
|
||||
for part in string_parts:
|
||||
if part.startswith("block"):
|
||||
string_block_count += 1
|
||||
|
||||
return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
|
||||
|
||||
|
||||
TOP_LEVEL_KEYS = []
|
||||
IGNORE_KEYS = []
|
||||
|
||||
|
||||
MAPPING_ENCODER = {
|
||||
"encoder.block.0": ["encoder.conv1"],
|
||||
"encoder.block.5": ["encoder.snake1"],
|
||||
"encoder.block.6": ["encoder.conv2"],
|
||||
"encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
|
||||
"encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
|
||||
"encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
|
||||
"encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
|
||||
"encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
|
||||
"encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
|
||||
}
|
||||
|
||||
MAPPING_QUANTIZER = {
|
||||
"quantizer.quantizers.*": ["quantizer.quantizers.*"],
|
||||
}
|
||||
|
||||
MAPPING_DECODER = {
|
||||
"decoder.model.0": ["decoder.conv1"],
|
||||
"decoder.model.5": ["decoder.snake1"],
|
||||
"decoder.model.6": ["decoder.conv2"],
|
||||
"decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
|
||||
"decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
|
||||
"decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
|
||||
"decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
|
||||
"decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
|
||||
"decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
|
||||
}
|
||||
|
||||
|
||||
MAPPING = {
|
||||
**MAPPING_ENCODER,
|
||||
**MAPPING_QUANTIZER,
|
||||
**MAPPING_DECODER,
|
||||
}
|
||||
|
||||
|
||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
if hf_shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
||||
f" {value.shape} for {full_name}"
|
||||
)
|
||||
|
||||
if weight_type == "weight":
|
||||
hf_pointer.weight.data = value
|
||||
elif weight_type == "weight_g":
|
||||
hf_pointer.weight_g.data = value
|
||||
elif weight_type == "weight_v":
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
elif weight_type == "alpha":
|
||||
hf_pointer.alpha.data = value
|
||||
logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def should_ignore(name, ignore_keys):
|
||||
for key in ignore_keys:
|
||||
if key.endswith(".*"):
|
||||
if name.startswith(key[:-1]):
|
||||
return True
|
||||
elif ".*." in key:
|
||||
prefix, suffix = key.split(".*.")
|
||||
if prefix in name and suffix in name:
|
||||
return True
|
||||
elif key in name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def recursively_load_weights(orig_dict, hf_model, model_name):
|
||||
unused_weights = []
|
||||
|
||||
if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
for name, value in orig_dict.items():
|
||||
is_used = False
|
||||
for key, mapped_key in MAPPING.items():
|
||||
regex = re.compile(key)
|
||||
if regex.search(name):
|
||||
if len(mapped_key) == 1:
|
||||
if mapped_key[0][0] == "q":
|
||||
mapped_key = ".".join(name.split(".")[:-1])
|
||||
else:
|
||||
mapped_key = mapped_key[0]
|
||||
elif len(mapped_key) == 3:
|
||||
integers = re.findall(r"\b\d+\b", name)
|
||||
if mapped_key[0][0] == "d":
|
||||
mapped_key = "{}.{}.{}{}.{}".format(
|
||||
mapped_key[0],
|
||||
str(int(integers[0]) - 1),
|
||||
mapped_key[1],
|
||||
str(int(integers[1]) - 1),
|
||||
mapped_key[2],
|
||||
)
|
||||
else:
|
||||
mapped_key = "{}.{}.{}{}.{}".format(
|
||||
mapped_key[0],
|
||||
str(int(integers[0]) - 1),
|
||||
mapped_key[1],
|
||||
str(int(integers[1]) + 1),
|
||||
mapped_key[2],
|
||||
)
|
||||
elif len(mapped_key) == 2:
|
||||
integers = re.findall(r"\b\d+\b", name)
|
||||
mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])
|
||||
|
||||
is_used = True
|
||||
if "weight_g" in name:
|
||||
weight_type = "weight_g"
|
||||
elif "weight_v" in name:
|
||||
weight_type = "weight_v"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "alpha" in name:
|
||||
weight_type = "alpha"
|
||||
elif "weight" in name:
|
||||
weight_type = "weight"
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
|
||||
if not is_used:
|
||||
unused_weights.append(name)
|
||||
|
||||
print(list(set(unused_weights)))
|
||||
|
||||
logger.warning(f"Unused weights: {unused_weights}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_checkpoint(
|
||||
model_name,
|
||||
checkpoint_path,
|
||||
pytorch_dump_folder_path,
|
||||
sample_rate=16000,
|
||||
repo_id=None,
|
||||
):
|
||||
model_dict = torch.load(checkpoint_path, "cpu")
|
||||
|
||||
config = DacConfig()
|
||||
|
||||
metadata = model_dict["metadata"]["kwargs"]
|
||||
config.encoder_hidden_size = metadata["encoder_dim"]
|
||||
config.downsampling_ratios = metadata["encoder_rates"]
|
||||
config.codebook_size = metadata["codebook_size"]
|
||||
config.n_codebooks = metadata["n_codebooks"]
|
||||
config.codebook_dim = metadata["codebook_dim"]
|
||||
config.decoder_hidden_size = metadata["decoder_dim"]
|
||||
config.upsampling_ratios = metadata["decoder_rates"]
|
||||
config.quantizer_dropout = float(metadata["quantizer_dropout"])
|
||||
config.sampling_rate = sample_rate
|
||||
|
||||
model = DacModel(config)
|
||||
feature_extractor = DacFeatureExtractor()
|
||||
feature_extractor.sampling_rate = sample_rate
|
||||
|
||||
original_checkpoint = model_dict["state_dict"]
|
||||
|
||||
model.apply_weight_norm()
|
||||
recursively_load_weights(original_checkpoint, model, model_name)
|
||||
model.remove_weight_norm()
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if repo_id:
|
||||
print("Pushing to the hub...")
|
||||
feature_extractor.push_to_hub(repo_id)
|
||||
model.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="dac_44khz",
|
||||
type=str,
|
||||
help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
|
||||
)
|
||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
||||
)
|
||||
parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_checkpoint(
|
||||
args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
|
||||
)
|
@ -1,285 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Wav2Vec2 checkpoint."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from functools import reduce
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import Wav2Vec2Processor, logging
|
||||
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig
|
||||
|
||||
# Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
|
||||
from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy # noqa: F401
|
||||
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAPPING = {
|
||||
"post_extract_proj": "feature_projection.projection",
|
||||
"models.0.layer_norm": "feature_projection.layer_norm",
|
||||
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
||||
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
||||
"w2v_encoder.proj": "lm_head",
|
||||
"mask_emb": "masked_spec_embed",
|
||||
}
|
||||
TOP_LEVEL_KEYS = [
|
||||
"lm_head",
|
||||
]
|
||||
|
||||
|
||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
if hf_shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
||||
f" {value.shape} for {full_name}"
|
||||
)
|
||||
|
||||
if weight_type == "weight":
|
||||
hf_pointer.weight.data = value
|
||||
elif weight_type == "weight_g":
|
||||
hf_pointer.weight_g.data = value
|
||||
elif weight_type == "weight_v":
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
else:
|
||||
hf_pointer.data = value
|
||||
|
||||
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def recursively_load_weights(fairseq_model, hf_model, is_headless):
|
||||
unused_weights = []
|
||||
fairseq_dict = fairseq_model.state_dict()
|
||||
|
||||
if not is_headless:
|
||||
feature_extractor = hf_model.data2vec_audio.feature_extractor
|
||||
pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed
|
||||
|
||||
else:
|
||||
feature_extractor = hf_model.feature_extractor
|
||||
pos_conv_embedding = hf_model.encoder.pos_conv_embed
|
||||
|
||||
for name, value in fairseq_dict.items():
|
||||
is_used = False
|
||||
if "conv_layers" in name:
|
||||
load_conv_layer(
|
||||
name,
|
||||
value,
|
||||
feature_extractor,
|
||||
unused_weights,
|
||||
)
|
||||
is_used = True
|
||||
elif "pos_conv" in name:
|
||||
load_pos_conv_layer(
|
||||
name,
|
||||
value,
|
||||
pos_conv_embedding,
|
||||
unused_weights,
|
||||
)
|
||||
is_used = True
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
if not is_headless:
|
||||
mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
|
||||
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
layer_index = name.split(key)[0].split(".")[-2]
|
||||
mapped_key = mapped_key.replace("*", layer_index)
|
||||
if "weight_g" in name:
|
||||
weight_type = "weight_g"
|
||||
elif "weight_v" in name:
|
||||
weight_type = "weight_v"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "weight" in name:
|
||||
# TODO: don't match quantizer.weight_proj
|
||||
weight_type = "weight"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
continue
|
||||
if not is_used:
|
||||
unused_weights.append(name)
|
||||
|
||||
logger.warning(f"Unused weights: {unused_weights}")
|
||||
|
||||
|
||||
def access_by_string(module, path):
|
||||
names = path.split(".")
|
||||
return reduce(getattr, names, module)
|
||||
|
||||
|
||||
def set_weights(full_name, module, fsq_value, hf_weight_path):
|
||||
hf_weight = access_by_string(module, hf_weight_path)
|
||||
hf_value = hf_weight.data
|
||||
|
||||
if fsq_value.shape != hf_value.shape:
|
||||
raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")
|
||||
hf_weight.data = fsq_value
|
||||
logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.")
|
||||
|
||||
|
||||
def load_conv_layer(full_name, value, feature_extractor, unused_weights):
|
||||
name = full_name.split("conv_layers.")[-1]
|
||||
items = name.split(".")
|
||||
layer_id = int(items[0])
|
||||
type_id = int(items[1])
|
||||
|
||||
weight_type = name.split(".")[-1]
|
||||
if type_id == 0:
|
||||
layer_type = "conv"
|
||||
elif type_id == 2:
|
||||
layer_type = "layer_norm"
|
||||
else:
|
||||
unused_weights.append(full_name)
|
||||
return
|
||||
|
||||
set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}")
|
||||
|
||||
|
||||
def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights):
|
||||
name = full_name.split("pos_conv.")[-1]
|
||||
items = name.split(".")
|
||||
layer_id = int(items[0])
|
||||
type_id = int(items[1])
|
||||
|
||||
weight_type = name.split(".")[-1]
|
||||
if type_id != 0:
|
||||
unused_weights.append(full_name)
|
||||
return
|
||||
else:
|
||||
layer_type = "conv"
|
||||
|
||||
set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_wav2vec2_checkpoint(
|
||||
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
if config_path is not None:
|
||||
config = Data2VecAudioConfig.from_pretrained(config_path)
|
||||
else:
|
||||
config = Data2VecAudioConfig()
|
||||
|
||||
if not is_finetuned:
|
||||
# Modify final_proj layer name
|
||||
hf_wav2vec = Data2VecAudioModel(config)
|
||||
data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||
|
||||
state_dict = torch.load(checkpoint_path)
|
||||
state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
|
||||
state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
|
||||
converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
|
||||
torch.save(state_dict, converted_ckpt)
|
||||
else:
|
||||
hf_wav2vec = Data2VecAudioForCTC(config)
|
||||
converted_ckpt = checkpoint_path
|
||||
|
||||
def load_data2vec(path):
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
|
||||
return model[0].eval()
|
||||
|
||||
model = load_data2vec(converted_ckpt)
|
||||
|
||||
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
|
||||
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
||||
input_audio = [x["array"] for x in ds[:4]["audio"]]
|
||||
|
||||
inputs = processor(input_audio, return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values
|
||||
attention_mask = inputs.attention_mask
|
||||
# input_values = inputs.input_values[:, :-1]
|
||||
# attention_mask = inputs.attention_mask[:, :-1]
|
||||
|
||||
hf_wav2vec.eval()
|
||||
model.eval()
|
||||
if is_finetuned:
|
||||
their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
|
||||
"encoder_out"
|
||||
].transpose(0, 1)
|
||||
our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]
|
||||
|
||||
pred_ids = torch.argmax(our_output, dim=-1)
|
||||
output_string = processor.batch_decode(pred_ids)
|
||||
|
||||
print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
|
||||
else:
|
||||
their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
|
||||
"layer_results"
|
||||
][-1][0].transpose(0, 1)
|
||||
our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]
|
||||
|
||||
print(our_output.shape, their_output.shape)
|
||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||
print("Do both models output the same tensors?", "🔥" if success else "💩")
|
||||
if not success:
|
||||
raise Exception("Something went wRoNg")
|
||||
|
||||
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if is_finetuned:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
else:
|
||||
processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument(
|
||||
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_wav2vec2_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
|
||||
)
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user