Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-31 09:04:37 +08:00)

Compare commits: feat/conti...v4.50.3 (15 commits)
Commits in this comparison (SHA1):

- a78e884f31
- e9a5e32b76
- 556b96d2e7
- f7ba365881
- b258dc35d5
- f4cfe5df33
- cfef91d802
- 6311953dd4
- 897130524b
- d9ccb9adbb
- e6ab93e702
- 650f607840
- 9abbb92297
- 0b057e66b5
- 26fbd6919a
````diff
@@ -105,59 +105,75 @@ inputs = processor.apply_chat_template(
     add_generation_prompt=True,
 ).to(model.device)
 
-output = model.generate(**inputs, max_new_tokens=50)
-print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
+output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
+print(processor.decode(output[0], skip_special_tokens=True))
 ```
 
-### Multi-image Inference
+Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
 
-```python
-model_id = "google/gemma-3-4b-it"
-model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
-processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
-
-url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
-url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
-messages = [
-    {
-        "role": "system",
-        "content": [
-            {"type": "text", "text": "You are a helpful assistant."}
-        ]
-    },
-    {
-        "role": "user", "content": [
-            {"type": "image", "url": url_cow},
-            {"type": "image", "url": url_stop},
-            {"type": "text", "text": "Are these two images identical?"},
-        ]
-    },
-]
-inputs = processor.apply_chat_template(
-    messages,
-    tokenize=True,
-    return_dict=True,
-    return_tensors="pt",
-    add_generation_prompt=True,
-).to(model.device)
-
-output = model.generate(**inputs, max_new_tokens=50)
-print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
-
+```py
+from transformers.utils.attention_visualizer import AttentionMaskVisualizer
+
+visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
+visualizer("<img>What is shown in this image?")
 ```
 
-### Text-only inference
+## Notes
 
-You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
-```python
-from transformers import AutoTokenizer, Gemma3ForCausalLM
-
-model_id = "google/gemma-3-1b-it"
-
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
-
-input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
+- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
+- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
+
+    ```py
+    url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
+    url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+
+    messages =[
+        {
+            "role": "system",
+            "content": [
+                {"type": "text", "text": "You are a helpful assistant."}
+            ]
+        },
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": url_cow},
+                {"type": "image", "url": url_cat},
+                {"type": "text", "text": "Which image is cuter?"},
+            ]
+        },
+    ]
+    ```
+- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
+- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
+- By default, images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
+
+    ```diff
+    inputs = processor.apply_chat_template(
+        messages,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+        add_generation_prompt=True,
+    +   do_pan_and_scan=True,
+        ).to("cuda")
+    ```
+- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
+
+    ```py
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        "google/gemma-3-1b-pt",
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        "google/gemma-3-1b-pt",
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        attn_implementation="sdpa"
+    )
+    input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
 
 outputs = model.generate(**input_ids, max_new_tokens=100)
 text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
````
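The `cache_implementation="static"` argument added above pre-allocates the KV cache to a fixed size, which keeps tensor shapes constant across decoding steps and makes the decoding loop compilable. A minimal text-only sketch of that combination, reusing the 1B checkpoint from the notes above (the compile step is optional and the exact speedup depends on hardware):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-pt", torch_dtype=torch.bfloat16, device_map="auto"
)

# A static cache keeps the same tensor shapes at every step, so the forward
# pass can be compiled once and reused for the whole generation.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```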
The next 51 hunks all make the same one-line change in the example scripts, replacing the development version pin with the 4.50.0 release:

```diff
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.50.0.dev0")
+check_min_version("4.50.0")
```

The surrounding `require_version` context lines reference, among others, the requirements files of examples/flax/speech-recognition, examples/tensorflow/contrastive-image-text, and the PyTorch examples for audio-classification, contrastive-image-text, image-classification, image-pretraining, instance-segmentation, language-modeling, object-detection, question-answering, semantic-segmentation, speech-recognition, summarization, text-classification, token-classification, and translation; the remaining hunks (for instance the GLUE script with its `task_to_keys` mapping) carry no `require_version` line.
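For reference, the two guards used throughout these example scripts are small runtime checks: `check_min_version` errors out if the installed `transformers` is older than the pinned release, and `require_version` does the same for other dependencies. A minimal sketch of how they are typically combined at the top of a script (the requirements path is just one example taken from the hunks above):

```py
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Fails fast if transformers is older than the release this example was written against.
check_min_version("4.50.0")

# Fails fast, with an actionable hint, if the datasets dependency is too old.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/text-classification/requirements.txt",
)
```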
							
								
								
									
setup.py (2 changes):

```diff
@@ -446,7 +446,7 @@ install_requires = [
 
 setup(
     name="transformers",
-    version="4.50.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.50.3",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
```
```diff
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.50.0.dev0"
+__version__ = "4.50.3"
 
 from typing import TYPE_CHECKING
 
```
```diff
@@ -585,7 +585,7 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.utils._pytree.tree_flatten(dictionary)[0]
 
 
-if is_torch_greater_or_equal("2.2"):
+if is_torch_greater_or_equal("2.3"):
     torch.utils._pytree.register_pytree_node(
         DynamicCache,
         _flatten_dynamic_cache,
```

```diff
@@ -611,21 +611,29 @@ class OffloadedCache(DynamicCache):
     """
 
     def __init__(self) -> None:
-        if not (torch.cuda.is_available() or (is_torch_greater_or_equal("2.7") and torch.xpu.is_available())):
+        if not (
+            torch.cuda.is_available()
+            or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
+        ):
             raise RuntimeError(
-                "OffloadedCache can only be used with a GPU" + (" or XPU" if is_torch_greater_or_equal("2.7") else "")
+                "OffloadedCache can only be used with a GPU"
+                + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
             )
 
         super().__init__()
         self.original_device = []
         self.prefetch_stream = None
-        self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal("2.7") else torch.cuda.Stream()
+        self.prefetch_stream = (
+            torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
+        )
         self.beam_idx = None  # used to delay beam search operations
 
     def prefetch_layer(self, layer_idx: int):
         "Starts prefetching the next layer cache"
         if layer_idx < len(self):
-            with self.prefetch_stream if is_torch_greater_or_equal("2.7") else torch.cuda.stream(self.prefetch_stream):
+            with self.prefetch_stream if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.stream(
+                self.prefetch_stream
+            ):
                 # Prefetch next layer tensors to GPU
                 device = self.original_device[layer_idx]
                 self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
```

```diff
@@ -643,7 +651,7 @@ class OffloadedCache(DynamicCache):
         "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
         if layer_idx < len(self):
             # Evict the previous layer if necessary
-            if is_torch_greater_or_equal("2.7"):
+            if is_torch_greater_or_equal("2.7", accept_dev=True):
                 torch.accelerator.current_stream().synchronize()
             else:
                 torch.cuda.current_stream().synchronize()
```
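As context for the `OffloadedCache` changes above: in practice this cache is usually selected through `generate` rather than constructed by hand. A small sketch, assuming a CUDA device (or, with PyTorch 2.7+, an XPU) is available and using GPT-2 purely as a placeholder checkpoint:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("KV-cache offloading trades a little latency for a lot of memory:", return_tensors="pt").to(model.device)

# "offloaded" selects OffloadedCache: past keys/values live on the CPU and are
# prefetched layer by layer on a side stream while the previous layer is evicted.
output = model.generate(**inputs, max_new_tokens=30, cache_implementation="offloaded")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```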
```diff
@@ -1122,7 +1122,9 @@ class PretrainedConfig(PushToHubMixin):
         Returns the config that is meant to be used with text IO. On most models, it is the original config instance
         itself. On specific composite models, it is under a set of valid names.
 
-        If `decoder` is set to `True`, then only search for decoder config names.
+        Args:
+            decoder (`Optional[bool]`, *optional*, defaults to `False`):
+                If set to `True`, then only search for decoder config names.
         """
         decoder_possible_text_config_names = ("decoder", "generator", "text_config")
         encoder_possible_text_config_names = ("text_encoder",)
```

```diff
@@ -1144,8 +1146,10 @@ class PretrainedConfig(PushToHubMixin):
                 "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
             )
         elif len(valid_text_config_names) == 1:
-            return getattr(self, valid_text_config_names[0])
-        return self
+            config_to_return = getattr(self, valid_text_config_names[0])
+        else:
+            config_to_return = self
+        return config_to_return
 
 
 def get_configuration_file(configuration_files: List[str]) -> str:
```
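A short illustration of how `get_text_config` is typically called on a composite (vision-language) configuration; the checkpoint name is only an example and the exact class names depend on the model:

```py
from transformers import AutoConfig

# Composite config: the language model's settings live in a nested text config.
config = AutoConfig.from_pretrained("google/gemma-3-4b-it")

# With decoder=True, only decoder-style names ("decoder", "generator", "text_config") are searched.
text_config = config.get_text_config(decoder=True)
print(type(text_config).__name__)  # e.g. a *TextConfig class
print(text_config.vocab_size)
```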
```diff
@@ -3887,9 +3887,14 @@ class GenerationMixin:
         beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
         beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
 
-        # Crop the static-shaped tensors to the actual size
-        sequences = sequences[:, :cur_len]
-        beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
+        # Crop the static-shaped tensors to the actual size.
+        # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
+        # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
+        # previous decoding iteration)
+        max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
+        output_length = decoder_prompt_len + max_generated_length
+        sequences = sequences[:, :output_length]
+        beam_indices = beam_indices[:, :max_generated_length]
 
         if return_dict_in_generate:
             if not output_scores:
```
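To see why `((beam_indices + 1).bool()).sum(dim=1)` in the new code counts generated tokens, here is a small standalone sketch with made-up beam indices (the `-1` entries mark decoding steps that were never filled):

```py
import torch

# Two returned beams, padded to 6 decoding steps; -1 means "no token generated at this step".
beam_indices = torch.tensor([
    [0, 1, 1, 0, -1, -1],    # this beam generated 4 tokens
    [1, 1, -1, -1, -1, -1],  # this beam generated 2 tokens
])

# -1 + 1 == 0 -> False; any real beam index (>= 0) -> True
per_beam_lengths = (beam_indices + 1).bool().sum(dim=1)
print(per_beam_lengths)        # tensor([4, 2])
print(per_beam_lengths.max())  # tensor(4) -> max_generated_length
```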
```diff
@@ -72,6 +72,8 @@ if is_vision_available():
             PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
             PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
         }
+    else:
+        pil_torch_interpolation_mapping = {}
 
 
 if TYPE_CHECKING:
```
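The added `else` branch is an instance of a common guarded-definition pattern: the name must exist at import time even when the optional dependency (here torchvision) is missing, so an empty fallback is defined. A standalone sketch of the same idea, independent of the transformers internals:

```py
# Guarded definition: callers can always import `interpolation_mapping`,
# and only hit a KeyError (not a NameError) when torchvision is absent.
try:
    from torchvision.transforms import InterpolationMode
    _torchvision_available = True
except ImportError:
    _torchvision_available = False

if _torchvision_available:
    interpolation_mapping = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
    }
else:
    interpolation_mapping = {}  # fallback keeps imports of this module working
```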
The next hunk (`@@ -1,62 +0,0 @@`) deletes an entire file, the ALBERT TF-checkpoint conversion script:

```python
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""

import argparse

import torch

from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ALBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
```
The following hunk (`@@ -1,389 +0,0 @@`) likewise deletes a whole file, the ALIGN checkpoint conversion script; the beginning of the removed file is shown below:

```python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALIGN checkpoints from the original repository."""

import argparse
import os

import align
import numpy as np
import requests
import tensorflow as tf
import torch
from PIL import Image
from tokenizer import Tokenizer

from transformers import (
    AlignConfig,
    AlignModel,
    AlignProcessor,
    BertConfig,
    BertTokenizer,
    EfficientNetConfig,
    EfficientNetImageProcessor,
)
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def preprocess(image):
    image = tf.image.resize(image, (346, 346))
    image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
    return image


def get_align_config():
    vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
    vision_config.image_size = 289
    vision_config.hidden_dim = 640
    vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
    vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
    vision_config.depthwise_padding = []

    text_config = BertConfig()
    config = AlignConfig.from_text_vision_configs(
        text_config=text_config, vision_config=vision_config, projection_dim=640
    )
    return config


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


def get_processor():
    image_processor = EfficientNetImageProcessor(
        do_center_crop=True,
        rescale_factor=1 / 127.5,
        rescale_offset=True,
        do_normalize=False,
        include_top=False,
        resample=Image.BILINEAR,
    )
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokenizer.model_max_length = 64
    processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
    return processor


# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
    # EfficientNet image encoder
    block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
    block_names = list(set(block_names))
    block_names = sorted(block_names)
    num_blocks = len(block_names)
    block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}

    rename_keys = []
    rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
    rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
    rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
    rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
    rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))

    for b in block_names:
        hf_b = block_name_mapping[b]
        rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
```
|         rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight")) |  | ||||||
|         rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias")) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight")) |  | ||||||
|         rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias")) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var") |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight")) |  | ||||||
|         rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias")) |  | ||||||
|         rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight")) |  | ||||||
|         rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias")) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight")) |  | ||||||
|         rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias")) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var") |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     key_mapping = {} |  | ||||||
|     for item in rename_keys: |  | ||||||
|         if item[0] in original_param_names: |  | ||||||
|             key_mapping[item[0]] = "vision_model." + item[1] |  | ||||||
|  |  | ||||||
|     # BERT text encoder |  | ||||||
|     rename_keys = [] |  | ||||||
|     old = "tf_bert_model/bert" |  | ||||||
|     new = "text_model" |  | ||||||
|     for i in range(12): |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.query.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/query/bias:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.query.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.key.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/key/bias:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.key.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.value.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/self/value/bias:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.self.value.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.output.dense.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.output.dense.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.intermediate.dense.weight", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             ( |  | ||||||
|                 f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0", |  | ||||||
|                 f"{new}.encoder.layer.{i}.intermediate.dense.bias", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias") |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight")) |  | ||||||
|     rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias")) |  | ||||||
|  |  | ||||||
|     rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight")) |  | ||||||
|     rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias")) |  | ||||||
|     rename_keys.append(("dense/kernel:0", "text_projection.weight")) |  | ||||||
|     rename_keys.append(("dense/bias:0", "text_projection.bias")) |  | ||||||
|     rename_keys.append(("dense/bias:0", "text_projection.bias")) |  | ||||||
|     rename_keys.append(("temperature:0", "temperature")) |  | ||||||
|  |  | ||||||
|     for item in rename_keys: |  | ||||||
|         if item[0] in original_param_names: |  | ||||||
|             key_mapping[item[0]] = item[1] |  | ||||||
|     return key_mapping |  | ||||||
|  |  | ||||||
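| # A couple of illustrative entries from the resulting mapping (vision keys receive the |  | ||||||
| # "vision_model." prefix inside rename_keys, text keys are already fully qualified): |  | ||||||
| # |  | ||||||
| #     "stem_conv/kernel:0"                        -> "vision_model.embeddings.convolution.weight" |  | ||||||
| #     "tf_bert_model/bert/pooler/dense/kernel:0"  -> "text_model.pooler.dense.weight" |  | ||||||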
|  |  | ||||||
| def replace_params(hf_params, tf_params, key_mapping): |  | ||||||
|     for key, value in tf_params.items(): |  | ||||||
|         if key not in key_mapping: |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         hf_key = key_mapping[key] |  | ||||||
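|         # TF stores regular conv kernels as (height, width, in_channels, out_channels) and |  | ||||||
|         # depthwise kernels as (height, width, channels, multiplier); PyTorch expects |  | ||||||
|         # (out, in, height, width), hence the permutes below. |  | ||||||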
|         if "_conv" in key and "kernel" in key: |  | ||||||
|             new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1) |  | ||||||
|         elif "embeddings" in key: |  | ||||||
|             new_hf_value = torch.from_numpy(value) |  | ||||||
|         elif "depthwise_kernel" in key: |  | ||||||
|             new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1) |  | ||||||
|         elif "kernel" in key: |  | ||||||
|             new_hf_value = torch.from_numpy(np.transpose(value)) |  | ||||||
|         elif "temperature" in key: |  | ||||||
|             new_hf_value = value |  | ||||||
|         elif "bn/gamma" or "bn/beta" in key: |  | ||||||
|             new_hf_value = torch.from_numpy(np.transpose(value)).squeeze() |  | ||||||
|         else: |  | ||||||
|             new_hf_value = torch.from_numpy(value) |  | ||||||
|  |  | ||||||
|         # Replace HF parameters with original TF model parameters |  | ||||||
|         hf_params[hf_key].copy_(new_hf_value) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our ALIGN structure. |  | ||||||
|     """ |  | ||||||
|     # Load original model |  | ||||||
|     seq_length = 64 |  | ||||||
|     tok = Tokenizer(seq_length) |  | ||||||
|     original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size()) |  | ||||||
|     original_model.compile() |  | ||||||
|     original_model.load_weights(checkpoint_path) |  | ||||||
|  |  | ||||||
|     tf_params = original_model.trainable_variables |  | ||||||
|     tf_non_train_params = original_model.non_trainable_variables |  | ||||||
|     tf_params = {param.name: param.numpy() for param in tf_params} |  | ||||||
|     for param in tf_non_train_params: |  | ||||||
|         tf_params[param.name] = param.numpy() |  | ||||||
|     tf_param_names = list(tf_params.keys()) |  | ||||||
|  |  | ||||||
|     # Load HuggingFace model |  | ||||||
|     config = get_align_config() |  | ||||||
|     hf_model = AlignModel(config).eval() |  | ||||||
|     hf_params = hf_model.state_dict() |  | ||||||
|  |  | ||||||
|     # Create src-to-dst parameter name mapping dictionary |  | ||||||
|     print("Converting parameters...") |  | ||||||
|     key_mapping = rename_keys(tf_param_names) |  | ||||||
|     replace_params(hf_params, tf_params, key_mapping) |  | ||||||
|  |  | ||||||
|     # Initialize processor |  | ||||||
|     processor = get_processor() |  | ||||||
|     inputs = processor( |  | ||||||
|         images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # HF model inference |  | ||||||
|     hf_model.eval() |  | ||||||
|     with torch.no_grad(): |  | ||||||
|         outputs = hf_model(**inputs) |  | ||||||
|  |  | ||||||
|     hf_image_features = outputs.image_embeds.detach().numpy() |  | ||||||
|     hf_text_features = outputs.text_embeds.detach().numpy() |  | ||||||
|  |  | ||||||
|     # Original model inference |  | ||||||
|     original_model.trainable = False |  | ||||||
|     tf_image_processor = EfficientNetImageProcessor( |  | ||||||
|         do_center_crop=True, |  | ||||||
|         do_rescale=False, |  | ||||||
|         do_normalize=False, |  | ||||||
|         include_top=False, |  | ||||||
|         resample=Image.BILINEAR, |  | ||||||
|     ) |  | ||||||
|     image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"] |  | ||||||
|     text = tok(tf.constant(["A picture of a cat"])) |  | ||||||
|  |  | ||||||
|     image_features = original_model.image_encoder(image, training=False) |  | ||||||
|     text_features = original_model.text_encoder(text, training=False) |  | ||||||
|  |  | ||||||
|     image_features = tf.nn.l2_normalize(image_features, axis=-1) |  | ||||||
|     text_features = tf.nn.l2_normalize(text_features, axis=-1) |  | ||||||
|  |  | ||||||
|     # Check whether the original and HF model outputs match using np.allclose |  | ||||||
|     if not np.allclose(image_features, hf_image_features, atol=1e-3): |  | ||||||
|         raise ValueError("The predicted image features are not the same.") |  | ||||||
|     if not np.allclose(text_features, hf_text_features, atol=1e-3): |  | ||||||
|         raise ValueError("The predicted text features are not the same.") |  | ||||||
|     print("Model outputs match!") |  | ||||||
|  |  | ||||||
|     if save_model: |  | ||||||
|         # Create folder to save model |  | ||||||
|         if not os.path.isdir(pytorch_dump_folder_path): |  | ||||||
|             os.mkdir(pytorch_dump_folder_path) |  | ||||||
|         # Save converted model and image processor |  | ||||||
|         hf_model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         # Push model and image processor to hub |  | ||||||
|         print("Pushing converted ALIGN to the hub...") |  | ||||||
|         processor.push_to_hub("align-base") |  | ||||||
|         hf_model.push_to_hub("align-base") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_path", |  | ||||||
|         default="./weights/model-weights", |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the pretrained TF ALIGN checkpoint.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default="hf_model", |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the output PyTorch model directory.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--save_model", action="store_true", help="Save model to local") |  | ||||||
|     parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub") |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub) |  | ||||||
| @ -1,162 +0,0 @@ | |||||||
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| import argparse |  | ||||||
| import glob |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import snapshot_download |  | ||||||
| from safetensors import safe_open |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     AddedToken, |  | ||||||
|     AriaForConditionalGeneration, |  | ||||||
|     AriaProcessor, |  | ||||||
|     AutoConfig, |  | ||||||
|     AutoTokenizer, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| EPILOG_TXT = """Example: |  | ||||||
|     python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria |  | ||||||
|  |  | ||||||
| Example for creating the old state dict file with Python: |  | ||||||
|  |  | ||||||
|     import torch |  | ||||||
|     from aria.model.language_model.aria_llama import AriaTextForCausalLM |  | ||||||
|  |  | ||||||
|     # load model |  | ||||||
|     kwargs = {"device_map": "auto", "torch_dtype": torch.float16} |  | ||||||
|     model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs) |  | ||||||
|  |  | ||||||
|     # load vision tower |  | ||||||
|     model.get_vision_tower().load_model() |  | ||||||
|  |  | ||||||
|     # Save state dict |  | ||||||
|     torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin") |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| KEYS_TO_MODIFY_MAPPING = { |  | ||||||
|     "vision_tower.vision_model": "vision_tower", |  | ||||||
|     "ln_ffn": "layer_norm", |  | ||||||
|     "ffn": "feed_forward", |  | ||||||
|     "ln_kv": "layer_norm_kv", |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_original_state_dict(model_id): |  | ||||||
|     directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) |  | ||||||
|  |  | ||||||
|     original_state_dict = {} |  | ||||||
|     for path in glob.glob(f"{directory_path}/*"): |  | ||||||
|         if path.endswith(".safetensors"): |  | ||||||
|             with safe_open(path, framework="pt", device="cpu") as f: |  | ||||||
|                 for key in f.keys(): |  | ||||||
|                     original_state_dict[key] = f.get_tensor(key) |  | ||||||
|  |  | ||||||
|     return original_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_state_dict_to_hf(state_dict): |  | ||||||
|     new_state_dict = {} |  | ||||||
|     for key, value in state_dict.items(): |  | ||||||
|         if key.endswith(".inv_freq"): |  | ||||||
|             continue |  | ||||||
|         for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): |  | ||||||
|             if key_to_modify in key: |  | ||||||
|                 key = key.replace(key_to_modify, new_key) |  | ||||||
|  |  | ||||||
|         new_state_dict[key] = value |  | ||||||
|     new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,)) |  | ||||||
|     new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,)) |  | ||||||
|  |  | ||||||
|     return new_state_dict |  | ||||||
|  |  | ||||||
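| # Illustrative effect of KEYS_TO_MODIFY_MAPPING (the parameter names below are made-up examples): |  | ||||||
| #     "vision_tower.vision_model.embeddings.patch_embedding.weight" |  | ||||||
| #         -> "vision_tower.embeddings.patch_embedding.weight" |  | ||||||
| #     "multi_modal_projector.ln_ffn.weight" -> "multi_modal_projector.layer_norm.weight" |  | ||||||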
|  |  | ||||||
| def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id): |  | ||||||
|     torch.set_default_dtype(torch.float16) |  | ||||||
|  |  | ||||||
|     tokenizer = AutoTokenizer.from_pretrained( |  | ||||||
|         text_model_id, |  | ||||||
|         extra_special_tokens={ |  | ||||||
|             "image_token": "<|img|>", |  | ||||||
|             "pad_token": "<pad>", |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True) |  | ||||||
|     tokenizer.add_special_tokens({"pad_token": "<pad>"}) |  | ||||||
|     tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" |  | ||||||
|  |  | ||||||
|     processor = AriaProcessor.from_pretrained( |  | ||||||
|         text_model_id, |  | ||||||
|         tokenizer=tokenizer, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     config = AutoConfig.from_pretrained(text_model_id) |  | ||||||
|     config.vision_config.hidden_size = 1152 |  | ||||||
|     config.vision_config.attention_heads = 16 |  | ||||||
|     config.pad_token_id = 2 |  | ||||||
|     config.image_token_index = 9 |  | ||||||
|     config.intermediate_size = config.moe_intermediate_size |  | ||||||
|     config.auto_map = { |  | ||||||
|         "AutoConfig": "modeling_aria.AriaConfig", |  | ||||||
|         "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration", |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     with torch.device("meta"): |  | ||||||
|         model = AriaForConditionalGeneration(config) |  | ||||||
|  |  | ||||||
|     state_dict = load_original_state_dict(old_state_dict_id) |  | ||||||
|  |  | ||||||
|     state_dict = convert_state_dict_to_hf(state_dict) |  | ||||||
|     model.load_state_dict(state_dict, strict=False, assign=True) |  | ||||||
|  |  | ||||||
|     # print("Saving models") |  | ||||||
|     # model.save_pretrained("local_aria", safe_serialization=False) |  | ||||||
|     # processor.save_pretrained("local_aria") |  | ||||||
|     print("Pushing to hub") |  | ||||||
|     model.push_to_hub(output_hub_path, create_pr=True) |  | ||||||
|     processor.push_to_hub(output_hub_path, create_pr=True) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): |  | ||||||
|     parser = argparse.ArgumentParser( |  | ||||||
|         epilog=EPILOG_TXT, |  | ||||||
|         formatter_class=argparse.RawDescriptionHelpFormatter, |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--text_model_id", |  | ||||||
|         default="rhymes-ai/Aria", |  | ||||||
|         help="Hub location of the text model", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--vision_model_id", |  | ||||||
|         default="rhymes-ai/Aria", |  | ||||||
|         help="Hub location of the vision model", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--output_hub_path", |  | ||||||
|         default="rhymes-ai/Aria", |  | ||||||
|         help="Location on the hub of the converted model", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--old_state_dict_id", |  | ||||||
|         default="rhymes-ai/Aria", |  | ||||||
|         help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     main() |  | ||||||
| @ -1,279 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| import torchaudio |  | ||||||
| from datasets import load_dataset |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
|  |  | ||||||
| from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_audio_spectrogram_transformer_config(model_name): |  | ||||||
|     config = ASTConfig() |  | ||||||
|  |  | ||||||
|     if "10-10" in model_name: |  | ||||||
|         pass |  | ||||||
|     elif "speech-commands" in model_name: |  | ||||||
|         config.max_length = 128 |  | ||||||
|     elif "12-12" in model_name: |  | ||||||
|         config.time_stride = 12 |  | ||||||
|         config.frequency_stride = 12 |  | ||||||
|     elif "14-14" in model_name: |  | ||||||
|         config.time_stride = 14 |  | ||||||
|         config.frequency_stride = 14 |  | ||||||
|     elif "16-16" in model_name: |  | ||||||
|         config.time_stride = 16 |  | ||||||
|         config.frequency_stride = 16 |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Model not supported") |  | ||||||
|  |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     if "speech-commands" in model_name: |  | ||||||
|         config.num_labels = 35 |  | ||||||
|         filename = "speech-commands-v2-id2label.json" |  | ||||||
|     else: |  | ||||||
|         config.num_labels = 527 |  | ||||||
|         filename = "audioset-id2label.json" |  | ||||||
|  |  | ||||||
|     id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|     config.id2label = id2label |  | ||||||
|     config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     return config |  | ||||||
|  |  | ||||||
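| # For illustration, with the "10-10" AudioSet checkpoint the default ASTConfig is kept |  | ||||||
| # and the 527 AudioSet labels are attached: |  | ||||||
| # |  | ||||||
| #     config = get_audio_spectrogram_transformer_config("ast-finetuned-audioset-10-10-0.4593") |  | ||||||
| #     assert config.num_labels == 527 |  | ||||||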
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     if "module.v" in name: |  | ||||||
|         name = name.replace("module.v", "audio_spectrogram_transformer") |  | ||||||
|     if "cls_token" in name: |  | ||||||
|         name = name.replace("cls_token", "embeddings.cls_token") |  | ||||||
|     if "dist_token" in name: |  | ||||||
|         name = name.replace("dist_token", "embeddings.distillation_token") |  | ||||||
|     if "pos_embed" in name: |  | ||||||
|         name = name.replace("pos_embed", "embeddings.position_embeddings") |  | ||||||
|     if "patch_embed.proj" in name: |  | ||||||
|         name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") |  | ||||||
|     # transformer blocks |  | ||||||
|     if "blocks" in name: |  | ||||||
|         name = name.replace("blocks", "encoder.layer") |  | ||||||
|     if "attn.proj" in name: |  | ||||||
|         name = name.replace("attn.proj", "attention.output.dense") |  | ||||||
|     if "attn" in name: |  | ||||||
|         name = name.replace("attn", "attention.self") |  | ||||||
|     if "norm1" in name: |  | ||||||
|         name = name.replace("norm1", "layernorm_before") |  | ||||||
|     if "norm2" in name: |  | ||||||
|         name = name.replace("norm2", "layernorm_after") |  | ||||||
|     if "mlp.fc1" in name: |  | ||||||
|         name = name.replace("mlp.fc1", "intermediate.dense") |  | ||||||
|     if "mlp.fc2" in name: |  | ||||||
|         name = name.replace("mlp.fc2", "output.dense") |  | ||||||
|     # final layernorm |  | ||||||
|     if "audio_spectrogram_transformer.norm" in name: |  | ||||||
|         name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm") |  | ||||||
|     # classifier head |  | ||||||
|     if "module.mlp_head.0" in name: |  | ||||||
|         name = name.replace("module.mlp_head.0", "classifier.layernorm") |  | ||||||
|     if "module.mlp_head.1" in name: |  | ||||||
|         name = name.replace("module.mlp_head.1", "classifier.dense") |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_state_dict(orig_state_dict, config): |  | ||||||
|     for key in orig_state_dict.copy().keys(): |  | ||||||
|         val = orig_state_dict.pop(key) |  | ||||||
|  |  | ||||||
|         if "qkv" in key: |  | ||||||
|             key_split = key.split(".") |  | ||||||
|             layer_num = int(key_split[3]) |  | ||||||
|             dim = config.hidden_size |  | ||||||
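|             # the original checkpoint stores a fused qkv projection of shape (3 * hidden_size, ...); |  | ||||||
|             # slice it into separate query / key / value tensors below |  | ||||||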
|             if "weight" in key: |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight" |  | ||||||
|                 ] = val[:dim, :] |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight" |  | ||||||
|                 ] = val[dim : dim * 2, :] |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight" |  | ||||||
|                 ] = val[-dim:, :] |  | ||||||
|             else: |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias" |  | ||||||
|                 ] = val[:dim] |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias" |  | ||||||
|                 ] = val[dim : dim * 2] |  | ||||||
|                 orig_state_dict[ |  | ||||||
|                     f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias" |  | ||||||
|                 ] = val[-dim:] |  | ||||||
|         else: |  | ||||||
|             orig_state_dict[rename_key(key)] = val |  | ||||||
|  |  | ||||||
|     return orig_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def remove_keys(state_dict): |  | ||||||
|     ignore_keys = [ |  | ||||||
|         "module.v.head.weight", |  | ||||||
|         "module.v.head.bias", |  | ||||||
|         "module.v.head_dist.weight", |  | ||||||
|         "module.v.head_dist.bias", |  | ||||||
|     ] |  | ||||||
|     for k in ignore_keys: |  | ||||||
|         state_dict.pop(k, None) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure. |  | ||||||
|     """ |  | ||||||
|     config = get_audio_spectrogram_transformer_config(model_name) |  | ||||||
|  |  | ||||||
|     model_name_to_url = { |  | ||||||
|         "ast-finetuned-audioset-10-10-0.4593": ( |  | ||||||
|             "https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-10-10-0.450": ( |  | ||||||
|             "https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-10-10-0.448": ( |  | ||||||
|             "https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-10-10-0.448-v2": ( |  | ||||||
|             "https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-12-12-0.447": ( |  | ||||||
|             "https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-14-14-0.443": ( |  | ||||||
|             "https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-audioset-16-16-0.442": ( |  | ||||||
|             "https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|         "ast-finetuned-speech-commands-v2": ( |  | ||||||
|             "https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1" |  | ||||||
|         ), |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     # load original state_dict |  | ||||||
|     checkpoint_url = model_name_to_url[model_name] |  | ||||||
|     state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") |  | ||||||
|     # remove some keys |  | ||||||
|     remove_keys(state_dict) |  | ||||||
|     # rename some keys |  | ||||||
|     new_state_dict = convert_state_dict(state_dict, config) |  | ||||||
|  |  | ||||||
|     # load 🤗 model |  | ||||||
|     model = ASTForAudioClassification(config) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     model.load_state_dict(new_state_dict) |  | ||||||
|  |  | ||||||
|     # verify outputs on dummy input |  | ||||||
|     # source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62 |  | ||||||
|     mean = -4.2677393 if "speech-commands" not in model_name else -6.845978 |  | ||||||
|     std = 4.5689974 if "speech-commands" not in model_name else 5.5654526 |  | ||||||
|     max_length = 1024 if "speech-commands" not in model_name else 128 |  | ||||||
|     feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length) |  | ||||||
|  |  | ||||||
|     if "speech-commands" in model_name: |  | ||||||
|         # TODO: Convert dataset to Parquet |  | ||||||
|         dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True) |  | ||||||
|         waveform = dataset[0]["audio"]["array"] |  | ||||||
|     else: |  | ||||||
|         filepath = hf_hub_download( |  | ||||||
|             repo_id="nielsr/audio-spectogram-transformer-checkpoint", |  | ||||||
|             filename="sample_audio.flac", |  | ||||||
|             repo_type="dataset", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         waveform, _ = torchaudio.load(filepath) |  | ||||||
|         waveform = waveform.squeeze().numpy() |  | ||||||
|  |  | ||||||
|     inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt") |  | ||||||
|  |  | ||||||
|     # forward pass |  | ||||||
|     outputs = model(**inputs) |  | ||||||
|     logits = outputs.logits |  | ||||||
|  |  | ||||||
|     if model_name == "ast-finetuned-audioset-10-10-0.4593": |  | ||||||
|         expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-10-10-0.450": |  | ||||||
|         expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-10-10-0.448": |  | ||||||
|         expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-10-10-0.448-v2": |  | ||||||
|         expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-12-12-0.447": |  | ||||||
|         expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-14-14-0.443": |  | ||||||
|         expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413]) |  | ||||||
|     elif model_name == "ast-finetuned-audioset-16-16-0.442": |  | ||||||
|         expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470]) |  | ||||||
|     elif model_name == "ast-finetuned-speech-commands-v2": |  | ||||||
|         expected_slice = torch.tensor([6.1589, -8.0566, -8.7984]) |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Unknown model name") |  | ||||||
|     if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4): |  | ||||||
|         raise ValueError("Logits don't match") |  | ||||||
|     print("Looks ok!") |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|         print(f"Saving model {model_name} to {pytorch_dump_folder_path}") |  | ||||||
|         model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         print(f"Saving feature extractor to {pytorch_dump_folder_path}") |  | ||||||
|         feature_extractor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         print("Pushing model and feature extractor to the hub...") |  | ||||||
|         model.push_to_hub(f"MIT/{model_name}") |  | ||||||
|         feature_extractor.push_to_hub(f"MIT/{model_name}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="ast-finetuned-audioset-10-10-0.4593", |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the Audio Spectrogram Transformer model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) |  | ||||||
| @ -493,7 +493,7 @@ class AutoImageProcessor: | |||||||
|                 image_processor_auto_map = config.auto_map["AutoImageProcessor"] |                 image_processor_auto_map = config.auto_map["AutoImageProcessor"] | ||||||
|  |  | ||||||
|         image_processor_class = None |         image_processor_class = None | ||||||
|         # TODO: @yoni, change logic in v4.50 (when use_fast set to True by default) |         # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default) | ||||||
|         if image_processor_type is not None: |         if image_processor_type is not None: | ||||||
|             # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor. |             # if use_fast is not set and the processor was saved with a fast processor, we use it, otherwise we use the slow processor. | ||||||
|             if use_fast is None: |             if use_fast is None: | ||||||
| @ -501,7 +501,7 @@ class AutoImageProcessor: | |||||||
|                 if not use_fast: |                 if not use_fast: | ||||||
|                     logger.warning_once( |                     logger.warning_once( | ||||||
|                         "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. " |                         "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. " | ||||||
|                         "`use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. " |                         "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. " | ||||||
|                         "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`." |                         "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`." | ||||||
|                     ) |                     ) | ||||||
|             # Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version. |             # Update class name to reflect the use_fast option. If class is not found, we fall back to the slow version. | ||||||
|  | |||||||
| @ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( | |||||||
|         ("fuyu", "FuyuForCausalLM"), |         ("fuyu", "FuyuForCausalLM"), | ||||||
|         ("gemma", "GemmaForCausalLM"), |         ("gemma", "GemmaForCausalLM"), | ||||||
|         ("gemma2", "Gemma2ForCausalLM"), |         ("gemma2", "Gemma2ForCausalLM"), | ||||||
|         ("gemma3", "Gemma3ForCausalLM"), |         ("gemma3", "Gemma3ForConditionalGeneration"), | ||||||
|         ("gemma3_text", "Gemma3ForCausalLM"), |         ("gemma3_text", "Gemma3ForCausalLM"), | ||||||
|         ("git", "GitForCausalLM"), |         ("git", "GitForCausalLM"), | ||||||
|         ("glm", "GlmForCausalLM"), |         ("glm", "GlmForCausalLM"), | ||||||
| @ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass): | |||||||
|         Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own |         Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own | ||||||
|         config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM. |         config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM. | ||||||
|         """ |         """ | ||||||
|         return config.get_text_config(decoder=True) |         possible_text_config_names = ("decoder", "generator", "text_config") | ||||||
|  |         text_config_names = [] | ||||||
|  |         for text_config_name in possible_text_config_names: | ||||||
|  |             if hasattr(config, text_config_name): | ||||||
|  |                 text_config_names += [text_config_name] | ||||||
|  |  | ||||||
|  |         text_config = config.get_text_config(decoder=True) | ||||||
|  |         if text_config_names and type(text_config) in cls._model_mapping.keys(): | ||||||
|  |             warnings.warn( | ||||||
|  |                 "Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. " | ||||||
|  |                 "`AutoModelForCausalLM` will be used to load only the text-to-text generation module.", | ||||||
|  |                 FutureWarning, | ||||||
|  |             ) | ||||||
|  |         return config | ||||||
|  |  | ||||||
|  |  | ||||||
| AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") | AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") | ||||||
|  | |||||||
| @ -1,273 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| import os |  | ||||||
| import re |  | ||||||
| from os import path |  | ||||||
| from typing import Dict, Union |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import split_torch_state_dict_into_shards |  | ||||||
| from safetensors.torch import save_file |  | ||||||
|  |  | ||||||
| from transformers import AutoTokenizer |  | ||||||
| from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME |  | ||||||
|  |  | ||||||
| from .configuration_bamba import BambaConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]: |  | ||||||
|     state_dict = {} |  | ||||||
|  |  | ||||||
|     for orig_k, param in original_sd.items(): |  | ||||||
|         k = orig_k.replace("backbone", "model") |  | ||||||
|  |  | ||||||
|         # for embeddings |  | ||||||
|         k = k.replace("embedding", "embed_tokens") |  | ||||||
|  |  | ||||||
|         # for mixer |  | ||||||
|         k = k.replace("mixer", "mamba") |  | ||||||
|  |  | ||||||
|         # for final layernorm |  | ||||||
|         k = k.replace("norm_f", "final_layernorm") |  | ||||||
|  |  | ||||||
|         # for block layernorm |  | ||||||
|         k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k) |  | ||||||
|         k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k) |  | ||||||
|  |  | ||||||
|         # for mlp |  | ||||||
|         k = k.replace("mlp.fc2", "feed_forward.down_proj") |  | ||||||
|  |  | ||||||
|         if "mlp.fc1" in k: |  | ||||||
|             param, param2 = torch.chunk(param, 2, dim=0) |  | ||||||
|             k2 = k.replace("mlp.fc1", "feed_forward.gate_proj") |  | ||||||
|             state_dict[k2] = param2 |  | ||||||
|             k = k.replace("mlp.fc1", "feed_forward.up_proj") |  | ||||||
|  |  | ||||||
|         if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or ( |  | ||||||
|             "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd |  | ||||||
|         ): |  | ||||||
|             # then this must be a mamba |  | ||||||
|             pass |  | ||||||
|         else: |  | ||||||
|             # for attn |  | ||||||
|             # - because mixer was replaced to mamba above |  | ||||||
|             k = k.replace("mamba.out_proj", "self_attn.o_proj") |  | ||||||
|             if "mamba.in_proj" in k: |  | ||||||
|                 m, n = param.shape |  | ||||||
|                 d = (m - n) // 2 |  | ||||||
|                 param, param2, param3 = torch.split(param, [n, d, d], dim=0) |  | ||||||
|                 k2 = k.replace("mamba.in_proj", "self_attn.k_proj") |  | ||||||
|                 state_dict[k2] = param2 |  | ||||||
|                 k2 = k.replace("mamba.in_proj", "self_attn.v_proj") |  | ||||||
|                 state_dict[k2] = param3 |  | ||||||
|                 k = k.replace("mamba.in_proj", "self_attn.q_proj") |  | ||||||
|  |  | ||||||
|         state_dict[k] = param |  | ||||||
|  |  | ||||||
|     return state_dict |  | ||||||
|  |  | ||||||
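| # Illustrative renames produced by the mapping above (layer index 0 is just an example): |  | ||||||
| #     "backbone.embedding.weight"     -> "model.embed_tokens.weight" |  | ||||||
| #     "backbone.layers.0.norm.weight" -> "model.layers.0.input_layernorm.weight" |  | ||||||
| #     "backbone.layers.0.mixer.A_log" -> "model.layers.0.mamba.A_log" |  | ||||||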
|  |  | ||||||
| # Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py |  | ||||||
| def convert_ssm_config_to_hf_config( |  | ||||||
|     config_ssm: Dict, |  | ||||||
|     **kwargs, |  | ||||||
| ) -> BambaConfig: |  | ||||||
|     """Convert a config from mamba_ssm to a BambaConfig from here.""" |  | ||||||
|     hf_config: BambaConfig = BambaConfig(**kwargs) |  | ||||||
|  |  | ||||||
|     hf_config.architectures = ["BambaForCausalLM"] |  | ||||||
|  |  | ||||||
|     # Set important values from config and recalculate other resulting entries |  | ||||||
|     hf_config.hidden_size = config_ssm["d_model"] |  | ||||||
|     hf_config.intermediate_size = config_ssm["d_intermediate"] |  | ||||||
|     hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head |  | ||||||
|     hf_config.num_hidden_layers = config_ssm["n_layer"] |  | ||||||
|     hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] |  | ||||||
|  |  | ||||||
|     # currently this script assumes config_ssm belongs to v2 |  | ||||||
|     if config_ssm["ssm_cfg"].get("layer") != "Mamba2": |  | ||||||
|         raise ValueError("Conversion script only supports Mamba2") |  | ||||||
|  |  | ||||||
|     # Set attention values |  | ||||||
|     attn_cfg = config_ssm.get("attn_cfg") |  | ||||||
|     if attn_cfg: |  | ||||||
|         assert attn_cfg["causal"], "Only support non-causal attention." |  | ||||||
|         assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias." |  | ||||||
|         assert not attn_cfg["out_proj_bias"], "Only support no out bias." |  | ||||||
|         hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"] |  | ||||||
|         hf_config.num_attention_heads = attn_cfg["num_heads"] |  | ||||||
|         hf_config.num_key_value_heads = attn_cfg["num_heads_kv"] |  | ||||||
|  |  | ||||||
|     attention_layer_indices = config_ssm.get("attn_layer_idx") |  | ||||||
|     if attention_layer_indices: |  | ||||||
|         hf_config.attn_layer_indices = attention_layer_indices |  | ||||||
|  |  | ||||||
|     # Pad the vocab size to a multiple of pad_vocab_size_multiple (usually 16, though 32 is also common in other models) |  | ||||||
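|     # e.g. with a hypothetical vocab_size of 50277 and pad_vocab_size_multiple of 16, |  | ||||||
|     # 50277 % 16 == 5, so the padded vocab size becomes 50277 + (16 - 5) = 50288 |  | ||||||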
|     vocab_size = config_ssm["vocab_size"] |  | ||||||
|     pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] |  | ||||||
|     if (vocab_size % pad_vocab_size_multiple) != 0: |  | ||||||
|         vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) |  | ||||||
|     hf_config.vocab_size = vocab_size |  | ||||||
|  |  | ||||||
|     return hf_config |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def save_single_safetensor( |  | ||||||
|     state_dict: Dict, |  | ||||||
|     save_directory: str, |  | ||||||
|     metadata: Dict, |  | ||||||
| ): |  | ||||||
|     save_file( |  | ||||||
|         state_dict, |  | ||||||
|         os.path.join(save_directory, SAFE_WEIGHTS_NAME), |  | ||||||
|         metadata, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def save_sharded_safetensors( |  | ||||||
|     state_dict: Dict, |  | ||||||
|     save_directory: str, |  | ||||||
|     metadata: Dict, |  | ||||||
|     max_shard_size: Union[int, str] = "5GB", |  | ||||||
| ): |  | ||||||
|     filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace( |  | ||||||
|         ".safetensors", "{suffix}.safetensors" |  | ||||||
|     ) |  | ||||||
|     state_dict_split = split_torch_state_dict_into_shards( |  | ||||||
|         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size |  | ||||||
|     ) |  | ||||||
|     index = { |  | ||||||
|         "metadata": state_dict_split.metadata, |  | ||||||
|         "weight_map": state_dict_split.tensor_to_filename, |  | ||||||
|     } |  | ||||||
|     # Save the index |  | ||||||
|     with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: |  | ||||||
|         content = json.dumps(index, indent=2, sort_keys=True) + "\n" |  | ||||||
|         f.write(content) |  | ||||||
|  |  | ||||||
|     filename_to_tensors = state_dict_split.filename_to_tensors.items() |  | ||||||
|     for shard_file, tensors in filename_to_tensors: |  | ||||||
|         shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors} |  | ||||||
|         save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py |  | ||||||
| def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( |  | ||||||
|     mamba_ssm_checkpoint_path: str, |  | ||||||
|     precision: str, |  | ||||||
|     output_dir: str, |  | ||||||
|     tokenizer_path: str = None, |  | ||||||
|     save_model: Union[bool, str] = True, |  | ||||||
| ) -> None: |  | ||||||
|     # load the tokenizer if provided; it is used to set the |  | ||||||
|     # token ids in the config file |  | ||||||
|     token_ids = {} |  | ||||||
|     if tokenizer_path: |  | ||||||
|         tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |  | ||||||
|         for key in [ |  | ||||||
|             "bos_token_id", |  | ||||||
|             "eos_token_id", |  | ||||||
|             "pad_token_id", |  | ||||||
|         ]: |  | ||||||
|             token_id = getattr(tokenizer, key, None) |  | ||||||
|             if token_id is not None: |  | ||||||
|                 token_ids[key] = token_id |  | ||||||
|  |  | ||||||
|     # some values cannot be read from the mamba_ssm config, so if they |  | ||||||
|     # differ from the defaults below, they have to be passed into |  | ||||||
|     # the function |  | ||||||
|     unsettables = { |  | ||||||
|         "mamba_d_head": 64, |  | ||||||
|         "mamba_d_state": 128, |  | ||||||
|         "mamba_n_groups": 1, |  | ||||||
|         "rms_norm_eps": 1e-5, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     # Load and save config based on name |  | ||||||
|     config_path = path.join(mamba_ssm_checkpoint_path, "config.json") |  | ||||||
|     with open(config_path, "r", encoding="utf-8") as json_file: |  | ||||||
|         config = json.load(json_file) |  | ||||||
|  |  | ||||||
|     # convert the config |  | ||||||
|     hf_config = convert_ssm_config_to_hf_config( |  | ||||||
|         config_ssm=config, |  | ||||||
|         **token_ids, |  | ||||||
|         **unsettables, |  | ||||||
|     ) |  | ||||||
|     hf_config.save_pretrained(output_dir) |  | ||||||
|  |  | ||||||
|     # Load state dict of the original model and transfer to hf model |  | ||||||
|     state_dict = torch.load( |  | ||||||
|         path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"), |  | ||||||
|         map_location="cpu", |  | ||||||
|         weights_only=True, |  | ||||||
|     ) |  | ||||||
|     # FIXME: allow other parameters to be passed in |  | ||||||
|     state_dict = convert_state_dict_from_mamba_ssm(state_dict) |  | ||||||
|  |  | ||||||
|     # Save the converted model to the output directory |  | ||||||
|     dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) |  | ||||||
|  |  | ||||||
|     save_file_fn = None |  | ||||||
|     if isinstance(save_model, bool) and save_model: |  | ||||||
|         save_file_fn = save_single_safetensor |  | ||||||
|     elif isinstance(save_model, str) and save_model == "sharded": |  | ||||||
|         save_file_fn = save_sharded_safetensors |  | ||||||
|  |  | ||||||
|     if save_file_fn: |  | ||||||
|         save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"}) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "-i", |  | ||||||
|         "--mamba_ssm_checkpoint_directory", |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "-p", |  | ||||||
|         "--precision", |  | ||||||
|         type=str, |  | ||||||
|         default="fp16", |  | ||||||
|         const="fp16", |  | ||||||
|         required=True, |  | ||||||
|         choices=("fp32", "fp16", "bf16"), |  | ||||||
|         help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "-t", |  | ||||||
|         "--tokenizer_model_path", |  | ||||||
|         type=str, |  | ||||||
|         default=None, |  | ||||||
|         required=False, |  | ||||||
|         help="Path to the tokenizer file.", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_mamba_ssm_checkpoint_file_to_huggingface_model_file( |  | ||||||
|         args.mamba_ssm_checkpoint_directory, |  | ||||||
|         args.precision, |  | ||||||
|         args.output_dir, |  | ||||||
|         tokenizer_path=args.tokenizer_model_path, |  | ||||||
|     ) |  | ||||||
| @ -1,263 +0,0 @@ | |||||||
| """Convert Bark checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from bark.generation import _load_model as _bark_load_model |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
|  |  | ||||||
| from transformers import EncodecConfig, EncodecModel, set_seed |  | ||||||
| from transformers.models.bark.configuration_bark import ( |  | ||||||
|     BarkCoarseConfig, |  | ||||||
|     BarkConfig, |  | ||||||
|     BarkFineConfig, |  | ||||||
|     BarkSemanticConfig, |  | ||||||
| ) |  | ||||||
| from transformers.models.bark.generation_configuration_bark import ( |  | ||||||
|     BarkCoarseGenerationConfig, |  | ||||||
|     BarkFineGenerationConfig, |  | ||||||
|     BarkGenerationConfig, |  | ||||||
|     BarkSemanticGenerationConfig, |  | ||||||
| ) |  | ||||||
| from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| set_seed(770) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| new_layer_name_dict = { |  | ||||||
|     "c_attn": "att_proj", |  | ||||||
|     "c_proj": "out_proj", |  | ||||||
|     "c_fc": "in_proj", |  | ||||||
|     "transformer.": "", |  | ||||||
|     "h.": "layers.", |  | ||||||
|     "ln_1": "layernorm_1", |  | ||||||
|     "ln_2": "layernorm_2", |  | ||||||
|     "ln_f": "layernorm_final", |  | ||||||
|     "wpe": "position_embeds_layer", |  | ||||||
|     "wte": "input_embeds_layer", |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| REMOTE_MODEL_PATHS = { |  | ||||||
|     "text_small": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "text.pt", |  | ||||||
|     }, |  | ||||||
|     "coarse_small": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "coarse.pt", |  | ||||||
|     }, |  | ||||||
|     "fine_small": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "fine.pt", |  | ||||||
|     }, |  | ||||||
|     "text": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "text_2.pt", |  | ||||||
|     }, |  | ||||||
|     "coarse": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "coarse_2.pt", |  | ||||||
|     }, |  | ||||||
|     "fine": { |  | ||||||
|         "repo_id": "suno/bark", |  | ||||||
|         "file_name": "fine_2.pt", |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| CUR_PATH = os.path.dirname(os.path.abspath(__file__)) |  | ||||||
| default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") |  | ||||||
| CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_ckpt_path(model_type, use_small=False): |  | ||||||
|     key = model_type |  | ||||||
|     if use_small: |  | ||||||
|         key += "_small" |  | ||||||
|     return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _download(from_hf_path, file_name): |  | ||||||
|     os.makedirs(CACHE_DIR, exist_ok=True) |  | ||||||
|     hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _load_model(ckpt_path, device, use_small=False, model_type="text"): |  | ||||||
|     if model_type == "text": |  | ||||||
|         ModelClass = BarkSemanticModel |  | ||||||
|         ConfigClass = BarkSemanticConfig |  | ||||||
|         GenerationConfigClass = BarkSemanticGenerationConfig |  | ||||||
|     elif model_type == "coarse": |  | ||||||
|         ModelClass = BarkCoarseModel |  | ||||||
|         ConfigClass = BarkCoarseConfig |  | ||||||
|         GenerationConfigClass = BarkCoarseGenerationConfig |  | ||||||
|     elif model_type == "fine": |  | ||||||
|         ModelClass = BarkFineModel |  | ||||||
|         ConfigClass = BarkFineConfig |  | ||||||
|         GenerationConfigClass = BarkFineGenerationConfig |  | ||||||
|     else: |  | ||||||
|         raise NotImplementedError() |  | ||||||
|     model_key = f"{model_type}_small" if use_small else model_type |  | ||||||
|     model_info = REMOTE_MODEL_PATHS[model_key] |  | ||||||
|     if not os.path.exists(ckpt_path): |  | ||||||
|         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") |  | ||||||
|         _download(model_info["repo_id"], model_info["file_name"]) |  | ||||||
|     checkpoint = torch.load(ckpt_path, map_location=device) |  | ||||||
|     # this is a hack: older checkpoints expose a single "vocab_size" instead of separate input/output vocab sizes |  | ||||||
|     model_args = checkpoint["model_args"] |  | ||||||
|     if "input_vocab_size" not in model_args: |  | ||||||
|         model_args["input_vocab_size"] = model_args["vocab_size"] |  | ||||||
|         model_args["output_vocab_size"] = model_args["vocab_size"] |  | ||||||
|         del model_args["vocab_size"] |  | ||||||
|  |  | ||||||
|     # convert Bark model arguments to HF Bark model arguments |  | ||||||
|     model_args["num_heads"] = model_args.pop("n_head") |  | ||||||
|     model_args["hidden_size"] = model_args.pop("n_embd") |  | ||||||
|     model_args["num_layers"] = model_args.pop("n_layer") |  | ||||||
|  |  | ||||||
|     model_config = ConfigClass(**checkpoint["model_args"]) |  | ||||||
|     model = ModelClass(config=model_config) |  | ||||||
|     model_generation_config = GenerationConfigClass() |  | ||||||
|  |  | ||||||
|     model.generation_config = model_generation_config |  | ||||||
|     state_dict = checkpoint["model"] |  | ||||||
|     # fixup checkpoint |  | ||||||
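|     # "_orig_mod." is the prefix torch.compile adds to parameter names; strip it if present |  | ||||||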
|     unwanted_prefix = "_orig_mod." |  | ||||||
|     for k, v in list(state_dict.items()): |  | ||||||
|         if k.startswith(unwanted_prefix): |  | ||||||
|             # replace part of the key with corresponding layer name in HF implementation |  | ||||||
|             new_k = k[len(unwanted_prefix) :] |  | ||||||
|             for old_layer_name in new_layer_name_dict: |  | ||||||
|                 new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name]) |  | ||||||
|  |  | ||||||
|             state_dict[new_k] = state_dict.pop(k) |  | ||||||
|  |  | ||||||
|     extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) |  | ||||||
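|     # ".attn.bias" holds the causal attention mask buffer in the original GPT-style code, not a learned weight |  | ||||||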
|     extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")} |  | ||||||
|     missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) |  | ||||||
|     missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")} |  | ||||||
|     if len(extra_keys) != 0: |  | ||||||
|         raise ValueError(f"extra keys found: {extra_keys}") |  | ||||||
|     if len(missing_keys) != 0: |  | ||||||
|         raise ValueError(f"missing keys: {missing_keys}") |  | ||||||
|     model.load_state_dict(state_dict, strict=False) |  | ||||||
|     n_params = model.num_parameters(exclude_embeddings=True) |  | ||||||
|     val_loss = checkpoint["best_val_loss"].item() |  | ||||||
|     logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") |  | ||||||
|     model.eval() |  | ||||||
|     model.to(device) |  | ||||||
|     del checkpoint, state_dict |  | ||||||
|  |  | ||||||
|     return model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"): |  | ||||||
|     if model_type not in ("text", "coarse", "fine"): |  | ||||||
|         raise NotImplementedError() |  | ||||||
|  |  | ||||||
|     device = "cpu"  # do conversion on cpu |  | ||||||
|  |  | ||||||
|     ckpt_path = _get_ckpt_path(model_type, use_small=use_small) |  | ||||||
|     model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small) |  | ||||||
|  |  | ||||||
|     # load bark initial model |  | ||||||
|     bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small) |  | ||||||
|  |  | ||||||
|     if model_type == "text": |  | ||||||
|         bark_model = bark_model["model"] |  | ||||||
|  |  | ||||||
|     if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params(): |  | ||||||
|         raise ValueError("initial and new models don't have the same number of parameters") |  | ||||||
|  |  | ||||||
|     # check if same output as the bark model |  | ||||||
|     batch_size = 5 |  | ||||||
|     sequence_length = 10 |  | ||||||
|  |  | ||||||
|     if model_type in ["text", "coarse"]: |  | ||||||
|         vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int) |  | ||||||
|         output_old_model = bark_model(vec)[0] |  | ||||||
|  |  | ||||||
|         output_new_model_total = model(vec) |  | ||||||
|  |  | ||||||
|         # take last logits |  | ||||||
|         output_new_model = output_new_model_total.logits[:, [-1], :] |  | ||||||
|  |  | ||||||
|     else: |  | ||||||
|         prediction_codebook_channel = 3 |  | ||||||
|         n_codes_total = 8 |  | ||||||
|         vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int) |  | ||||||
|  |  | ||||||
|         output_new_model_total = model(prediction_codebook_channel, vec) |  | ||||||
|         output_old_model = bark_model(prediction_codebook_channel, vec) |  | ||||||
|  |  | ||||||
|         output_new_model = output_new_model_total.logits |  | ||||||
|  |  | ||||||
|     # any output difference should come from differences in the self-attention implementation |  | ||||||
|     if output_new_model.shape != output_old_model.shape: |  | ||||||
|         raise ValueError("initial and new outputs don't have the same shape") |  | ||||||
|     if (output_new_model - output_old_model).abs().max().item() > 1e-3: |  | ||||||
|         raise ValueError("initial and new outputs are not equal") |  | ||||||
|  |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_whole_bark_model( |  | ||||||
|     semantic_path, |  | ||||||
|     coarse_path, |  | ||||||
|     fine_path, |  | ||||||
|     append_text, |  | ||||||
|     hub_path, |  | ||||||
|     folder_path, |  | ||||||
| ): |  | ||||||
|     pytorch_dump_folder_path = os.path.join(folder_path, append_text) |  | ||||||
|  |  | ||||||
|     semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json")) |  | ||||||
|     coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json")) |  | ||||||
|     fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json")) |  | ||||||
|     codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz") |  | ||||||
|  |  | ||||||
|     semantic = BarkSemanticModel.from_pretrained(semantic_path) |  | ||||||
|     coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path) |  | ||||||
|     fineAcoustic = BarkFineModel.from_pretrained(fine_path) |  | ||||||
|     codec = EncodecModel.from_pretrained("facebook/encodec_24khz") |  | ||||||
|  |  | ||||||
|     bark_config = BarkConfig.from_sub_model_configs( |  | ||||||
|         semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     bark_generation_config = BarkGenerationConfig.from_sub_model_configs( |  | ||||||
|         semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     bark = BarkModel(bark_config) |  | ||||||
|  |  | ||||||
|     bark.semantic = semantic |  | ||||||
|     bark.coarse_acoustics = coarseAcoustic |  | ||||||
|     bark.fine_acoustics = fineAcoustic |  | ||||||
|     bark.codec_model = codec |  | ||||||
|  |  | ||||||
|     bark.generation_config = bark_generation_config |  | ||||||
|  |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|  |  | ||||||
|     parser.add_argument("model_type", type=str, help="text, coarse or fine.") |  | ||||||
|     parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.") |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small) |  | ||||||
| @ -1,156 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2020 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BART checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import fairseq |  | ||||||
| import torch |  | ||||||
| from packaging import version |  | ||||||
| from torch import nn |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     BartConfig, |  | ||||||
|     BartForConditionalGeneration, |  | ||||||
|     BartForSequenceClassification, |  | ||||||
|     BartModel, |  | ||||||
|     BartTokenizer, |  | ||||||
| ) |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"] |  | ||||||
| extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification} |  | ||||||
| if version.parse(fairseq.__version__) < version.parse("0.9.0"): |  | ||||||
|     raise Exception("requires fairseq >= 0.9.0") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| SAMPLE_TEXT = " Hello world! cécé herlolip" |  | ||||||
|  |  | ||||||
| mnli_rename_keys = [ |  | ||||||
|     ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"), |  | ||||||
|     ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"), |  | ||||||
|     ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"), |  | ||||||
|     ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"), |  | ||||||
| ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def remove_ignore_keys_(state_dict): |  | ||||||
|     ignore_keys = [ |  | ||||||
|         "encoder.version", |  | ||||||
|         "decoder.version", |  | ||||||
|         "model.encoder.version", |  | ||||||
|         "model.decoder.version", |  | ||||||
|         "_float_tensor", |  | ||||||
|     ] |  | ||||||
|     for k in ignore_keys: |  | ||||||
|         state_dict.pop(k, None) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(dct, old, new): |  | ||||||
|     val = dct.pop(old) |  | ||||||
|     dct[new] = val |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_xsum_checkpoint(checkpoint_path): |  | ||||||
|     """Checkpoint path should end in model.pt""" |  | ||||||
|     sd = torch.load(checkpoint_path, map_location="cpu") |  | ||||||
|     hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval() |  | ||||||
|     hub_interface.model.load_state_dict(sd["model"]) |  | ||||||
|     return hub_interface |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def make_linear_from_emb(emb): |  | ||||||
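|     # build an lm_head whose weights are tied to the token embedding matrix |  | ||||||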
|     vocab_size, emb_size = emb.weight.shape |  | ||||||
|     lin_layer = nn.Linear(emb_size, vocab_size, bias=False) |  | ||||||
|     lin_layer.weight.data = emb.weight.data |  | ||||||
|     return lin_layer |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our BART structure. |  | ||||||
|     """ |  | ||||||
|     if not os.path.exists(checkpoint_path): |  | ||||||
|         bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval() |  | ||||||
|     else: |  | ||||||
|         bart = load_xsum_checkpoint(checkpoint_path) |  | ||||||
|  |  | ||||||
|     bart.model.upgrade_state_dict(bart.model.state_dict()) |  | ||||||
|     if hf_checkpoint_name is None: |  | ||||||
|         hf_checkpoint_name = checkpoint_path.replace(".", "-") |  | ||||||
|     config = BartConfig.from_pretrained(hf_checkpoint_name) |  | ||||||
|     tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0) |  | ||||||
|     tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0) |  | ||||||
|     if not torch.eq(tokens, tokens2).all(): |  | ||||||
|         raise ValueError( |  | ||||||
|             f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if checkpoint_path == "bart.large.mnli": |  | ||||||
|         state_dict = bart.state_dict() |  | ||||||
|         remove_ignore_keys_(state_dict) |  | ||||||
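|         # BART ties its input embeddings, so the shared table can be copied from the decoder embeddings |  | ||||||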
|         state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"] |  | ||||||
|         for src, dest in mnli_rename_keys: |  | ||||||
|             rename_key(state_dict, src, dest) |  | ||||||
|         model = BartForSequenceClassification(config).eval() |  | ||||||
|         model.load_state_dict(state_dict) |  | ||||||
|         fairseq_output = bart.predict("mnli", tokens, return_logits=True) |  | ||||||
|         new_model_outputs = model(tokens)[0]  # logits |  | ||||||
|     else:  # no classification heads to worry about |  | ||||||
|         state_dict = bart.model.state_dict() |  | ||||||
|         remove_ignore_keys_(state_dict) |  | ||||||
|         state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"] |  | ||||||
|         fairseq_output = bart.extract_features(tokens) |  | ||||||
|         if hf_checkpoint_name == "facebook/bart-large": |  | ||||||
|             model = BartModel(config).eval() |  | ||||||
|             model.load_state_dict(state_dict) |  | ||||||
|             new_model_outputs = model(tokens).model[0] |  | ||||||
|         else: |  | ||||||
|             model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt |  | ||||||
|             model.model.load_state_dict(state_dict) |  | ||||||
|             if hasattr(model, "lm_head"): |  | ||||||
|                 model.lm_head = make_linear_from_emb(model.model.shared) |  | ||||||
|             new_model_outputs = model.model(tokens)[0] |  | ||||||
|  |  | ||||||
|     # Check results |  | ||||||
|     if fairseq_output.shape != new_model_outputs.shape: |  | ||||||
|         raise ValueError( |  | ||||||
|             f"`fairseq_output` shape and `new_model_outputs` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape=}" |  | ||||||
|         ) |  | ||||||
|     if (fairseq_output != new_model_outputs).any().item(): |  | ||||||
|         raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`") |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum" |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config) |  | ||||||
| @ -1,373 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BEiT checkpoints from the unilm repository.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from datasets import load_dataset |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     BeitConfig, |  | ||||||
|     BeitForImageClassification, |  | ||||||
|     BeitForMaskedImageModeling, |  | ||||||
|     BeitForSemanticSegmentation, |  | ||||||
|     BeitImageProcessor, |  | ||||||
| ) |  | ||||||
| from transformers.image_utils import PILImageResampling |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # here we list all keys to be renamed (original name on the left, our name on the right) |  | ||||||
| def create_rename_keys(config, has_lm_head=False, is_semantic=False): |  | ||||||
|     prefix = "backbone." if is_semantic else "" |  | ||||||
|  |  | ||||||
|     rename_keys = [] |  | ||||||
|     for i in range(config.num_hidden_layers): |  | ||||||
|         # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias")) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append( |  | ||||||
|             (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias") |  | ||||||
|         ) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight")) |  | ||||||
|         rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias")) |  | ||||||
|  |  | ||||||
|     # projection layer + position embeddings |  | ||||||
|     rename_keys.extend( |  | ||||||
|         [ |  | ||||||
|             (f"{prefix}cls_token", "beit.embeddings.cls_token"), |  | ||||||
|             (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"), |  | ||||||
|             (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"), |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     if has_lm_head: |  | ||||||
|         # mask token + shared relative position bias + layernorm |  | ||||||
|         rename_keys.extend( |  | ||||||
|             [ |  | ||||||
|                 ("mask_token", "beit.embeddings.mask_token"), |  | ||||||
|                 ( |  | ||||||
|                     "rel_pos_bias.relative_position_bias_table", |  | ||||||
|                     "beit.encoder.relative_position_bias.relative_position_bias_table", |  | ||||||
|                 ), |  | ||||||
|                 ( |  | ||||||
|                     "rel_pos_bias.relative_position_index", |  | ||||||
|                     "beit.encoder.relative_position_bias.relative_position_index", |  | ||||||
|                 ), |  | ||||||
|                 ("norm.weight", "layernorm.weight"), |  | ||||||
|                 ("norm.bias", "layernorm.bias"), |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|     elif is_semantic: |  | ||||||
|         # semantic segmentation classification heads |  | ||||||
|         rename_keys.extend( |  | ||||||
|             [ |  | ||||||
|                 ("decode_head.conv_seg.weight", "decode_head.classifier.weight"), |  | ||||||
|                 ("decode_head.conv_seg.bias", "decode_head.classifier.bias"), |  | ||||||
|                 ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), |  | ||||||
|                 ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|     else: |  | ||||||
|         # layernorm + classification head |  | ||||||
|         rename_keys.extend( |  | ||||||
|             [ |  | ||||||
|                 ("fc_norm.weight", "beit.pooler.layernorm.weight"), |  | ||||||
|                 ("fc_norm.bias", "beit.pooler.layernorm.bias"), |  | ||||||
|                 ("head.weight", "classifier.weight"), |  | ||||||
|                 ("head.bias", "classifier.bias"), |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     return rename_keys |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # we split up the matrix of each encoder layer into queries, keys and values |  | ||||||
| def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False): |  | ||||||
|     for i in range(config.num_hidden_layers): |  | ||||||
|         prefix = "backbone." if is_semantic else "" |  | ||||||
|         # queries, keys and values |  | ||||||
|         in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight") |  | ||||||
|         q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias") |  | ||||||
|         v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias") |  | ||||||
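|         # the original BEiT attention only has biases on the query and value projections (no key bias) |  | ||||||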
|  |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ |  | ||||||
|             : config.hidden_size, : |  | ||||||
|         ] |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ |  | ||||||
|             config.hidden_size : config.hidden_size * 2, : |  | ||||||
|         ] |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ |  | ||||||
|             -config.hidden_size :, : |  | ||||||
|         ] |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias |  | ||||||
|  |  | ||||||
|         # gamma_1 and gamma_2 |  | ||||||
|         # we call them lambda because otherwise they are renamed when using .from_pretrained |  | ||||||
|         gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1") |  | ||||||
|         gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2") |  | ||||||
|  |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1 |  | ||||||
|         state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2 |  | ||||||
|  |  | ||||||
|         # relative_position bias table + index |  | ||||||
|         if not has_lm_head: |  | ||||||
|             # each layer has its own relative position bias |  | ||||||
|             table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table") |  | ||||||
|             index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index") |  | ||||||
|  |  | ||||||
|             state_dict[ |  | ||||||
|                 f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table" |  | ||||||
|             ] = table |  | ||||||
|             state_dict[ |  | ||||||
|                 f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index" |  | ||||||
|             ] = index |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(dct, old, new): |  | ||||||
|     val = dct.pop(old) |  | ||||||
|     dct[new] = val |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     im = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|     return im |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our BEiT structure. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # define default BEiT configuration |  | ||||||
|     config = BeitConfig() |  | ||||||
|     has_lm_head = False |  | ||||||
|     is_semantic = False |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     # set config parameters based on URL |  | ||||||
|     if checkpoint_url[-9:-4] == "pt22k": |  | ||||||
|         # masked image modeling |  | ||||||
|         config.use_shared_relative_position_bias = True |  | ||||||
|         config.use_mask_token = True |  | ||||||
|         has_lm_head = True |  | ||||||
|     elif checkpoint_url[-9:-4] == "ft22k": |  | ||||||
|         # intermediate fine-tuning on ImageNet-22k |  | ||||||
|         config.use_relative_position_bias = True |  | ||||||
|         config.num_labels = 21841 |  | ||||||
|         filename = "imagenet-22k-id2label.json" |  | ||||||
|         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|         id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|         # this dataset contains 21843 labels but the model only has 21841 |  | ||||||
|         # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18 |  | ||||||
|         del id2label[9205] |  | ||||||
|         del id2label[15027] |  | ||||||
|         config.id2label = id2label |  | ||||||
|         config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|     elif checkpoint_url[-8:-4] == "to1k": |  | ||||||
|         # fine-tuning on ImageNet-1k |  | ||||||
|         config.use_relative_position_bias = True |  | ||||||
|         config.num_labels = 1000 |  | ||||||
|         filename = "imagenet-1k-id2label.json" |  | ||||||
|         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|         id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|         config.id2label = id2label |  | ||||||
|         config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|         if "384" in checkpoint_url: |  | ||||||
|             config.image_size = 384 |  | ||||||
|         if "512" in checkpoint_url: |  | ||||||
|             config.image_size = 512 |  | ||||||
|     elif "ade20k" in checkpoint_url: |  | ||||||
|         # fine-tuning on ADE20k (semantic segmentation) |  | ||||||
|         config.use_relative_position_bias = True |  | ||||||
|         config.num_labels = 150 |  | ||||||
|         filename = "ade20k-id2label.json" |  | ||||||
|         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|         id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|         config.id2label = id2label |  | ||||||
|         config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|         config.image_size = 640 |  | ||||||
|         is_semantic = True |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'") |  | ||||||
|  |  | ||||||
|     # size of the architecture |  | ||||||
|     if "base" in checkpoint_url: |  | ||||||
|         pass |  | ||||||
|     elif "large" in checkpoint_url: |  | ||||||
|         config.hidden_size = 1024 |  | ||||||
|         config.intermediate_size = 4096 |  | ||||||
|         config.num_hidden_layers = 24 |  | ||||||
|         config.num_attention_heads = 16 |  | ||||||
|         if "ade20k" in checkpoint_url: |  | ||||||
|             config.image_size = 640 |  | ||||||
|             config.out_indices = [7, 11, 15, 23] |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Should either find 'base' or 'large' in checkpoint URL") |  | ||||||
|  |  | ||||||
|     # load state_dict of original model, remove and rename some keys |  | ||||||
|     state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True) |  | ||||||
|     state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"] |  | ||||||
|  |  | ||||||
|     rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic) |  | ||||||
|     for src, dest in rename_keys: |  | ||||||
|         rename_key(state_dict, src, dest) |  | ||||||
|     read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic) |  | ||||||
|     if is_semantic: |  | ||||||
|         # add prefix to decoder keys |  | ||||||
|         for key, val in state_dict.copy().items(): |  | ||||||
|             val = state_dict.pop(key) |  | ||||||
|             if key.startswith("backbone.fpn"): |  | ||||||
|                 key = key.replace("backbone.fpn", "fpn") |  | ||||||
|             state_dict[key] = val |  | ||||||
|  |  | ||||||
|     # load HuggingFace model |  | ||||||
|     if checkpoint_url[-9:-4] == "pt22k": |  | ||||||
|         model = BeitForMaskedImageModeling(config) |  | ||||||
|     elif "ade20k" in checkpoint_url: |  | ||||||
|         model = BeitForSemanticSegmentation(config) |  | ||||||
|     else: |  | ||||||
|         model = BeitForImageClassification(config) |  | ||||||
|     model.eval() |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|  |  | ||||||
|     # Check outputs on an image |  | ||||||
|     if is_semantic: |  | ||||||
|         image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False) |  | ||||||
|         ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True) |  | ||||||
|         image = Image.open(ds[0]["file"]) |  | ||||||
|     else: |  | ||||||
|         image_processor = BeitImageProcessor( |  | ||||||
|             size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False |  | ||||||
|         ) |  | ||||||
|         image = prepare_img() |  | ||||||
|  |  | ||||||
|     encoding = image_processor(images=image, return_tensors="pt") |  | ||||||
|     pixel_values = encoding["pixel_values"] |  | ||||||
|  |  | ||||||
|     outputs = model(pixel_values) |  | ||||||
|     logits = outputs.logits |  | ||||||
|  |  | ||||||
|     # verify logits |  | ||||||
|     expected_shape = torch.Size([1, 1000]) |  | ||||||
|     if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"): |  | ||||||
|         expected_shape = torch.Size([1, 196, 8192]) |  | ||||||
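|         # masked image modeling head: 196 = (224 / 16) ** 2 patches, 8192 = size of the visual-token vocabulary |  | ||||||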
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"): |  | ||||||
|         expected_shape = torch.Size([1, 196, 8192]) |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"): |  | ||||||
|         expected_shape = torch.Size([1, 21841]) |  | ||||||
|         expected_logits = torch.tensor([2.2288, 2.4671, 0.7395]) |  | ||||||
|         expected_class_idx = 2397 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"): |  | ||||||
|         expected_shape = torch.Size([1, 21841]) |  | ||||||
|         expected_logits = torch.tensor([1.6881, -0.2787, 0.5901]) |  | ||||||
|         expected_class_idx = 2396 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"): |  | ||||||
|         expected_logits = torch.tensor([0.1241, 0.0798, -0.6569]) |  | ||||||
|         expected_class_idx = 285 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"): |  | ||||||
|         expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108]) |  | ||||||
|         expected_class_idx = 281 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"): |  | ||||||
|         expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147]) |  | ||||||
|         expected_class_idx = 761 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"): |  | ||||||
|         expected_logits = torch.tensor([0.4610, -0.0928, 0.2086]) |  | ||||||
|         expected_class_idx = 761 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"): |  | ||||||
|         expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837]) |  | ||||||
|         expected_class_idx = 761 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"): |  | ||||||
|         expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]]) |  | ||||||
|         expected_class_idx = 761 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"): |  | ||||||
|         expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852]) |  | ||||||
|         expected_class_idx = 761 |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"): |  | ||||||
|         expected_shape = (1, 150, 160, 160) |  | ||||||
|         expected_logits = torch.tensor( |  | ||||||
|             [ |  | ||||||
|                 [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]], |  | ||||||
|                 [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]], |  | ||||||
|                 [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]], |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|     elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"): |  | ||||||
|         expected_shape = (1, 150, 160, 160) |  | ||||||
|         expected_logits = torch.tensor( |  | ||||||
|             [ |  | ||||||
|                 [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]], |  | ||||||
|                 [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]], |  | ||||||
|                 [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]], |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Can't verify logits as model is not supported") |  | ||||||
|  |  | ||||||
|     if logits.shape != expected_shape: |  | ||||||
|         raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}") |  | ||||||
|     if not has_lm_head: |  | ||||||
|         if is_semantic: |  | ||||||
|             if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3): |  | ||||||
|                 raise ValueError("First elements of logits not as expected") |  | ||||||
|         else: |  | ||||||
|             print("Predicted class idx:", logits.argmax(-1).item()) |  | ||||||
|  |  | ||||||
|             if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3): |  | ||||||
|                 raise ValueError("First elements of logits not as expected") |  | ||||||
|             if logits.argmax(-1).item() != expected_class_idx: |  | ||||||
|                 raise ValueError("Predicted class index not as expected") |  | ||||||
|  |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     print(f"Saving model to {pytorch_dump_folder_path}") |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     print(f"Saving image processor to {pytorch_dump_folder_path}") |  | ||||||
|     image_processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|  |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_url", |  | ||||||
|         default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth", |  | ||||||
|         type=str, |  | ||||||
|         help="URL to the original PyTorch checkpoint (.pth file).", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) |  | ||||||
| @ -1,246 +0,0 @@ | |||||||
| # Copyright 2020 The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """ |  | ||||||
| This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now |  | ||||||
| deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert |  | ||||||
|  |  | ||||||
| TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert |  | ||||||
| weight names to the original names, so the model can be imported with Hugging Face Transformers. |  | ||||||
|  |  | ||||||
| You may adapt this script to include classification/MLM/NSP/etc. heads. |  | ||||||
|  |  | ||||||
| Note: This script only works with an older version of the TensorFlow models repository (<= v2.3.0). |  | ||||||
|       Models trained with newer versions are not compatible with this script. |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| import tensorflow as tf |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BertConfig, BertModel |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_tf2_weights_in_bert(model, tf_checkpoint_path, config): |  | ||||||
|     tf_path = os.path.abspath(tf_checkpoint_path) |  | ||||||
|     logger.info(f"Converting TensorFlow checkpoint from {tf_path}") |  | ||||||
|     # Load weights from TF model |  | ||||||
|     init_vars = tf.train.list_variables(tf_path) |  | ||||||
|     names = [] |  | ||||||
|     arrays = [] |  | ||||||
|     layer_depth = [] |  | ||||||
|     for full_name, shape in init_vars: |  | ||||||
|         # logger.info(f"Loading TF weight {name} with shape {shape}") |  | ||||||
|         name = full_name.split("/") |  | ||||||
|         if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]: |  | ||||||
|             logger.info(f"Skipping non-model layer {full_name}") |  | ||||||
|             continue |  | ||||||
|         if "optimizer" in full_name: |  | ||||||
|             logger.info(f"Skipping optimization layer {full_name}") |  | ||||||
|             continue |  | ||||||
|         if name[0] == "model": |  | ||||||
|             # ignore initial 'model' |  | ||||||
|             name = name[1:] |  | ||||||
|         # figure out how many levels deep the name is |  | ||||||
|         depth = 0 |  | ||||||
|         for _name in name: |  | ||||||
|             if _name.startswith("layer_with_weights"): |  | ||||||
|                 depth += 1 |  | ||||||
|             else: |  | ||||||
|                 break |  | ||||||
|         layer_depth.append(depth) |  | ||||||
|         # read data |  | ||||||
|         array = tf.train.load_variable(tf_path, full_name) |  | ||||||
|         names.append("/".join(name)) |  | ||||||
|         arrays.append(array) |  | ||||||
|     logger.info(f"Read a total of {len(arrays):,} layers") |  | ||||||
|  |  | ||||||
|     # Sanity check |  | ||||||
|     if len(set(layer_depth)) != 1: |  | ||||||
|         raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})") |  | ||||||
|     layer_depth = list(set(layer_depth))[0] |  | ||||||
|     if layer_depth != 1: |  | ||||||
|         raise ValueError( |  | ||||||
|             "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP" |  | ||||||
|             " heads." |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     # convert layers |  | ||||||
|     logger.info("Converting weights...") |  | ||||||
|     for full_name, array in zip(names, arrays): |  | ||||||
|         name = full_name.split("/") |  | ||||||
|         pointer = model |  | ||||||
|         trace = [] |  | ||||||
|         for i, m_name in enumerate(name): |  | ||||||
|             if m_name == ".ATTRIBUTES": |  | ||||||
|                 # variable names end with .ATTRIBUTES/VARIABLE_VALUE |  | ||||||
|                 break |  | ||||||
|             if m_name.startswith("layer_with_weights"): |  | ||||||
|                 layer_num = int(m_name.split("-")[-1]) |  | ||||||
|                 if layer_num <= 2: |  | ||||||
|                     # embedding layers |  | ||||||
|                     # layer_num 0: word_embeddings |  | ||||||
|                     # layer_num 1: position_embeddings |  | ||||||
|                     # layer_num 2: token_type_embeddings |  | ||||||
|                     continue |  | ||||||
|                 elif layer_num == 3: |  | ||||||
|                     # embedding LayerNorm |  | ||||||
|                     trace.extend(["embeddings", "LayerNorm"]) |  | ||||||
|                     pointer = getattr(pointer, "embeddings") |  | ||||||
|                     pointer = getattr(pointer, "LayerNorm") |  | ||||||
|                 elif layer_num > 3 and layer_num < config.num_hidden_layers + 4: |  | ||||||
|                     # encoder layers |  | ||||||
|                     trace.extend(["encoder", "layer", str(layer_num - 4)]) |  | ||||||
|                     pointer = getattr(pointer, "encoder") |  | ||||||
|                     pointer = getattr(pointer, "layer") |  | ||||||
|                     pointer = pointer[layer_num - 4] |  | ||||||
|                 elif layer_num == config.num_hidden_layers + 4: |  | ||||||
|                     # pooler layer |  | ||||||
|                     trace.extend(["pooler", "dense"]) |  | ||||||
|                     pointer = getattr(pointer, "pooler") |  | ||||||
|                     pointer = getattr(pointer, "dense") |  | ||||||
|             elif m_name == "embeddings": |  | ||||||
|                 trace.append("embeddings") |  | ||||||
|                 pointer = getattr(pointer, "embeddings") |  | ||||||
|                 if layer_num == 0: |  | ||||||
|                     trace.append("word_embeddings") |  | ||||||
|                     pointer = getattr(pointer, "word_embeddings") |  | ||||||
|                 elif layer_num == 1: |  | ||||||
|                     trace.append("position_embeddings") |  | ||||||
|                     pointer = getattr(pointer, "position_embeddings") |  | ||||||
|                 elif layer_num == 2: |  | ||||||
|                     trace.append("token_type_embeddings") |  | ||||||
|                     pointer = getattr(pointer, "token_type_embeddings") |  | ||||||
|                 else: |  | ||||||
|                     raise ValueError(f"Unknown embedding layer with name {full_name}") |  | ||||||
|                 trace.append("weight") |  | ||||||
|                 pointer = getattr(pointer, "weight") |  | ||||||
|             elif m_name == "_attention_layer": |  | ||||||
|                 # self-attention layer |  | ||||||
|                 trace.extend(["attention", "self"]) |  | ||||||
|                 pointer = getattr(pointer, "attention") |  | ||||||
|                 pointer = getattr(pointer, "self") |  | ||||||
|             elif m_name == "_attention_layer_norm": |  | ||||||
|                 # output attention norm |  | ||||||
|                 trace.extend(["attention", "output", "LayerNorm"]) |  | ||||||
|                 pointer = getattr(pointer, "attention") |  | ||||||
|                 pointer = getattr(pointer, "output") |  | ||||||
|                 pointer = getattr(pointer, "LayerNorm") |  | ||||||
|             elif m_name == "_attention_output_dense": |  | ||||||
|                 # output attention dense |  | ||||||
|                 trace.extend(["attention", "output", "dense"]) |  | ||||||
|                 pointer = getattr(pointer, "attention") |  | ||||||
|                 pointer = getattr(pointer, "output") |  | ||||||
|                 pointer = getattr(pointer, "dense") |  | ||||||
|             elif m_name == "_output_dense": |  | ||||||
|                 # output dense |  | ||||||
|                 trace.extend(["output", "dense"]) |  | ||||||
|                 pointer = getattr(pointer, "output") |  | ||||||
|                 pointer = getattr(pointer, "dense") |  | ||||||
|             elif m_name == "_output_layer_norm": |  | ||||||
|                 # output layer norm |  | ||||||
|                 trace.extend(["output", "LayerNorm"]) |  | ||||||
|                 pointer = getattr(pointer, "output") |  | ||||||
|                 pointer = getattr(pointer, "LayerNorm") |  | ||||||
|             elif m_name == "_key_dense": |  | ||||||
|                 # attention key |  | ||||||
|                 trace.append("key") |  | ||||||
|                 pointer = getattr(pointer, "key") |  | ||||||
|             elif m_name == "_query_dense": |  | ||||||
|                 # attention query |  | ||||||
|                 trace.append("query") |  | ||||||
|                 pointer = getattr(pointer, "query") |  | ||||||
|             elif m_name == "_value_dense": |  | ||||||
|                 # attention value |  | ||||||
|                 trace.append("value") |  | ||||||
|                 pointer = getattr(pointer, "value") |  | ||||||
|             elif m_name == "_intermediate_dense": |  | ||||||
|                 # intermediate dense (feed-forward) |  | ||||||
|                 trace.extend(["intermediate", "dense"]) |  | ||||||
|                 pointer = getattr(pointer, "intermediate") |  | ||||||
|                 pointer = getattr(pointer, "dense") |  | ||||||
|             elif m_name == "_output_layer_norm": |  | ||||||
|                 # output layer norm |  | ||||||
|                 trace.append("output") |  | ||||||
|                 pointer = getattr(pointer, "output") |  | ||||||
|             # weights & biases |  | ||||||
|             elif m_name in ["bias", "beta"]: |  | ||||||
|                 trace.append("bias") |  | ||||||
|                 pointer = getattr(pointer, "bias") |  | ||||||
|             elif m_name in ["kernel", "gamma"]: |  | ||||||
|                 trace.append("weight") |  | ||||||
|                 pointer = getattr(pointer, "weight") |  | ||||||
|             else: |  | ||||||
|                 logger.warning(f"Ignored {m_name}") |  | ||||||
|         # for certain layers a reshape is necessary: the TF2 checkpoint stores the attention |  | ||||||
|         # projection weights with an explicit heads dimension that the PyTorch model keeps flattened |  | ||||||
|         trace = ".".join(trace) |  | ||||||
|         if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match( |  | ||||||
|             r"(\S+)\.attention\.output\.dense\.weight", trace |  | ||||||
|         ): |  | ||||||
|             array = array.reshape(pointer.data.shape) |  | ||||||
|         if "kernel" in full_name: |  | ||||||
|             array = array.transpose() |  | ||||||
|         if pointer.shape == array.shape: |  | ||||||
|             pointer.data = torch.from_numpy(array) |  | ||||||
|         else: |  | ||||||
|             raise ValueError( |  | ||||||
|                 f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:" |  | ||||||
|                 f" {array.shape}" |  | ||||||
|             ) |  | ||||||
|         logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}") |  | ||||||
|     return model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path): |  | ||||||
|     # Instantiate model |  | ||||||
|     logger.info(f"Loading model based on config from {config_path}...") |  | ||||||
|     config = BertConfig.from_json_file(config_path) |  | ||||||
|     model = BertModel(config) |  | ||||||
|  |  | ||||||
|     # Load weights from checkpoint |  | ||||||
|     logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...") |  | ||||||
|     load_tf2_weights_in_bert(model, tf_checkpoint_path, config) |  | ||||||
|  |  | ||||||
|     # Save pytorch-model |  | ||||||
|     logger.info(f"Saving PyTorch model to {pytorch_dump_path}...") |  | ||||||
|     torch.save(model.state_dict(), pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--bert_config_file", |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="The config json file corresponding to the BERT model. This specifies the model architecture.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to the output PyTorch model (must include filename).", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) |  | ||||||
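|  |  | ||||||
| # Example invocation (a hedged sketch: the script filename and every path below are |  | ||||||
| # placeholders, not files shipped with the repository): |  | ||||||
| # |  | ||||||
| #   python convert_tf2_bert_checkpoint_to_pytorch.py \ |  | ||||||
| #       --tf_checkpoint_path ./tf2_bert_checkpoint \ |  | ||||||
| #       --bert_config_file ./tf2_bert_checkpoint/bert_config.json \ |  | ||||||
| #       --pytorch_dump_path ./converted/pytorch_model.bin |  | ||||||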
| @ -1,62 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2018 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BERT checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): |  | ||||||
|     # Initialise PyTorch model |  | ||||||
|     config = BertConfig.from_json_file(bert_config_file) |  | ||||||
|     print(f"Building PyTorch model from configuration: {config}") |  | ||||||
|     model = BertForPreTraining(config) |  | ||||||
|  |  | ||||||
|     # Load weights from tf checkpoint |  | ||||||
|     load_tf_weights_in_bert(model, config, tf_checkpoint_path) |  | ||||||
|  |  | ||||||
|     # Save pytorch-model |  | ||||||
|     print(f"Saving PyTorch model to {pytorch_dump_path}") |  | ||||||
|     torch.save(model.state_dict(), pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--bert_config_file", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help=( |  | ||||||
|             "The config json file corresponding to the pre-trained BERT model. \n" |  | ||||||
|             "This specifies the model architecture." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) |  | ||||||
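|  |  | ||||||
| # Example follow-up (a hedged sketch; the paths are placeholders): once the conversion has |  | ||||||
| # run, the dumped state dict can be loaded back into a model built from the same config: |  | ||||||
| # |  | ||||||
| #   config = BertConfig.from_json_file("./bert_config.json") |  | ||||||
| #   model = BertForPreTraining(config) |  | ||||||
| #   model.load_state_dict(torch.load("./pytorch_model.bin")) |  | ||||||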
| @ -1,112 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2018 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
|  |  | ||||||
| import numpy as np |  | ||||||
| import tensorflow as tf |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BertModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): |  | ||||||
|     """ |  | ||||||
|     Convert a Hugging Face BertModel PyTorch instance into a TensorFlow 1.x checkpoint. |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         model: BertModel PyTorch model instance to be converted |  | ||||||
|         ckpt_dir: directory in which the TensorFlow checkpoint will be saved |  | ||||||
|         model_name: model name used to build the checkpoint filename |  | ||||||
|  |  | ||||||
|     Currently supported HF models (Y = supported, N = not yet supported): |  | ||||||
|  |  | ||||||
|         - Y BertModel |  | ||||||
|         - N BertForMaskedLM |  | ||||||
|         - N BertForPreTraining |  | ||||||
|         - N BertForMultipleChoice |  | ||||||
|         - N BertForNextSentencePrediction |  | ||||||
|         - N BertForSequenceClassification |  | ||||||
|         - N BertForQuestionAnswering |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value") |  | ||||||
|  |  | ||||||
|     var_map = ( |  | ||||||
|         ("layer.", "layer_"), |  | ||||||
|         ("word_embeddings.weight", "word_embeddings"), |  | ||||||
|         ("position_embeddings.weight", "position_embeddings"), |  | ||||||
|         ("token_type_embeddings.weight", "token_type_embeddings"), |  | ||||||
|         (".", "/"), |  | ||||||
|         ("LayerNorm/weight", "LayerNorm/gamma"), |  | ||||||
|         ("LayerNorm/bias", "LayerNorm/beta"), |  | ||||||
|         ("weight", "kernel"), |  | ||||||
|     ) |  | ||||||
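|     # For illustration: applying the patterns above plus the "bert/" prefix added in |  | ||||||
|     # to_tf_var_name() turns a PyTorch key such as "encoder.layer.0.attention.self.query.weight" |  | ||||||
|     # into the TF variable name "bert/encoder/layer_0/attention/self/query/kernel". |  | ||||||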
|  |  | ||||||
|     if not os.path.isdir(ckpt_dir): |  | ||||||
|         os.makedirs(ckpt_dir) |  | ||||||
|  |  | ||||||
|     state_dict = model.state_dict() |  | ||||||
|  |  | ||||||
|     def to_tf_var_name(name: str): |  | ||||||
|         for patt, repl in iter(var_map): |  | ||||||
|             name = name.replace(patt, repl) |  | ||||||
|         return f"bert/{name}" |  | ||||||
|  |  | ||||||
|     def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session): |  | ||||||
|         tf_dtype = tf.dtypes.as_dtype(tensor.dtype) |  | ||||||
|         tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) |  | ||||||
|         session.run(tf.variables_initializer([tf_var])) |  | ||||||
|         session.run(tf_var) |  | ||||||
|         return tf_var |  | ||||||
|  |  | ||||||
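|     # Note: tf.reset_default_graph, tf.Session, tf.get_variable and tf.train.Saver used in this |  | ||||||
|     # function are TensorFlow 1.x graph-mode APIs; under TensorFlow 2.x they are only reachable |  | ||||||
|     # through the tf.compat.v1 namespace. |  | ||||||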
|     tf.reset_default_graph() |  | ||||||
|     with tf.Session() as session: |  | ||||||
|         for var_name in state_dict: |  | ||||||
|             tf_name = to_tf_var_name(var_name) |  | ||||||
|             torch_tensor = state_dict[var_name].numpy() |  | ||||||
|             if any(x in var_name for x in tensors_to_transpose): |  | ||||||
|                 torch_tensor = torch_tensor.T |  | ||||||
|             tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) |  | ||||||
|             tf_var.assign(tf.cast(torch_tensor, tf_var.dtype)) |  | ||||||
|             tf_weight = session.run(tf_var) |  | ||||||
|             print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}") |  | ||||||
|  |  | ||||||
|         saver = tf.train.Saver(tf.trainable_variables()) |  | ||||||
|         saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(raw_args=None): |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model" |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin") |  | ||||||
|     parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model") |  | ||||||
|     args = parser.parse_args(raw_args) |  | ||||||
|  |  | ||||||
|     model = BertModel.from_pretrained( |  | ||||||
|         pretrained_model_name_or_path=args.model_name, |  | ||||||
|         state_dict=torch.load(args.pytorch_model_path), |  | ||||||
|         cache_dir=args.cache_dir, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     main() |  | ||||||
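|  |  | ||||||
| # Example invocation (a hedged sketch; the script filename and local paths are placeholders, |  | ||||||
| # while the model name matches the example given in the --model_name help above): |  | ||||||
| # |  | ||||||
| #   python convert_bert_pytorch_checkpoint_to_tf.py \ |  | ||||||
| #       --model_name google-bert/bert-base-uncased \ |  | ||||||
| #       --pytorch_model_path ./pytorch_model.bin \ |  | ||||||
| #       --tf_cache_dir ./tf_export |  | ||||||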
| @ -1,188 +0,0 @@ | |||||||
| # Copyright 2022 The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """ |  | ||||||
| This script converts an LM-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT |  | ||||||
| model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository: |  | ||||||
|  |  | ||||||
| https://github.com/tensorflow/models/tree/master/official/projects/token_dropping |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import tensorflow as tf |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BertConfig, BertForMaskedLM |  | ||||||
| from transformers.models.bert.modeling_bert import ( |  | ||||||
|     BertIntermediate, |  | ||||||
|     BertLayer, |  | ||||||
|     BertOutput, |  | ||||||
|     BertPooler, |  | ||||||
|     BertSelfAttention, |  | ||||||
|     BertSelfOutput, |  | ||||||
| ) |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str): |  | ||||||
|     def get_masked_lm_array(name: str): |  | ||||||
|         full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE" |  | ||||||
|         array = tf.train.load_variable(tf_checkpoint_path, full_name) |  | ||||||
|  |  | ||||||
|         if "kernel" in name: |  | ||||||
|             array = array.transpose() |  | ||||||
|  |  | ||||||
|         return torch.from_numpy(array) |  | ||||||
|  |  | ||||||
|     def get_encoder_array(name: str): |  | ||||||
|         full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE" |  | ||||||
|         array = tf.train.load_variable(tf_checkpoint_path, full_name) |  | ||||||
|  |  | ||||||
|         if "kernel" in name: |  | ||||||
|             array = array.transpose() |  | ||||||
|  |  | ||||||
|         return torch.from_numpy(array) |  | ||||||
|  |  | ||||||
|     def get_encoder_layer_array(layer_index: int, name: str): |  | ||||||
|         full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE" |  | ||||||
|         array = tf.train.load_variable(tf_checkpoint_path, full_name) |  | ||||||
|  |  | ||||||
|         if "kernel" in name: |  | ||||||
|             array = array.transpose() |  | ||||||
|  |  | ||||||
|         return torch.from_numpy(array) |  | ||||||
|  |  | ||||||
|     def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape): |  | ||||||
|         full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE" |  | ||||||
|         array = tf.train.load_variable(tf_checkpoint_path, full_name) |  | ||||||
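|         # The attention weights in this checkpoint carry an explicit heads dimension, so the |  | ||||||
|         # flat Hugging Face parameter shape supplied by the caller is restored here before use. |  | ||||||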
|         array = array.reshape(original_shape) |  | ||||||
|  |  | ||||||
|         if "kernel" in name: |  | ||||||
|             array = array.transpose() |  | ||||||
|  |  | ||||||
|         return torch.from_numpy(array) |  | ||||||
|  |  | ||||||
|     print(f"Loading model based on config from {config_path}...") |  | ||||||
|     config = BertConfig.from_json_file(config_path) |  | ||||||
|     model = BertForMaskedLM(config) |  | ||||||
|  |  | ||||||
|     # Layers |  | ||||||
|     for layer_index in range(0, config.num_hidden_layers): |  | ||||||
|         layer: BertLayer = model.bert.encoder.layer[layer_index] |  | ||||||
|  |  | ||||||
|         # Self-attention |  | ||||||
|         self_attn: BertSelfAttention = layer.attention.self |  | ||||||
|  |  | ||||||
|         self_attn.query.weight.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape |  | ||||||
|         ) |  | ||||||
|         self_attn.query.bias.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_query_dense/bias", self_attn.query.bias.data.shape |  | ||||||
|         ) |  | ||||||
|         self_attn.key.weight.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape |  | ||||||
|         ) |  | ||||||
|         self_attn.key.bias.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_key_dense/bias", self_attn.key.bias.data.shape |  | ||||||
|         ) |  | ||||||
|         self_attn.value.weight.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape |  | ||||||
|         ) |  | ||||||
|         self_attn.value.bias.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_value_dense/bias", self_attn.value.bias.data.shape |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         # Self-attention Output |  | ||||||
|         self_output: BertSelfOutput = layer.attention.output |  | ||||||
|  |  | ||||||
|         self_output.dense.weight.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape |  | ||||||
|         ) |  | ||||||
|         self_output.dense.bias.data = get_encoder_attention_layer_array( |  | ||||||
|             layer_index, "_output_dense/bias", self_output.dense.bias.data.shape |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma") |  | ||||||
|         self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta") |  | ||||||
|  |  | ||||||
|         # Intermediate |  | ||||||
|         intermediate: BertIntermediate = layer.intermediate |  | ||||||
|  |  | ||||||
|         intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel") |  | ||||||
|         intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias") |  | ||||||
|  |  | ||||||
|         # Output |  | ||||||
|         bert_output: BertOutput = layer.output |  | ||||||
|  |  | ||||||
|         bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel") |  | ||||||
|         bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias") |  | ||||||
|  |  | ||||||
|         bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma") |  | ||||||
|         bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta") |  | ||||||
|  |  | ||||||
|     # Embeddings |  | ||||||
|     model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings") |  | ||||||
|     model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings") |  | ||||||
|     model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma") |  | ||||||
|     model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta") |  | ||||||
|  |  | ||||||
|     # LM Head |  | ||||||
|     lm_head = model.cls.predictions.transform |  | ||||||
|  |  | ||||||
|     lm_head.dense.weight.data = get_masked_lm_array("dense/kernel") |  | ||||||
|     lm_head.dense.bias.data = get_masked_lm_array("dense/bias") |  | ||||||
|  |  | ||||||
|     lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma") |  | ||||||
|     lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta") |  | ||||||
|  |  | ||||||
|     model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table") |  | ||||||
|  |  | ||||||
|     # Pooling |  | ||||||
|     model.bert.pooler = BertPooler(config=config) |  | ||||||
|     model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel") |  | ||||||
|     model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias") |  | ||||||
|  |  | ||||||
|     # Export final model |  | ||||||
|     model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|     # Integration test - should load without any errors ;) |  | ||||||
|     new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path) |  | ||||||
|     print(new_model.eval()) |  | ||||||
|  |  | ||||||
|     print("Model conversion was done successfully!") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--bert_config_file", |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="The config json file corresponding to the BERT model. This specifies the model architecture.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to the output PyTorch model.", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path) |  | ||||||
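|  |  | ||||||
| # Example invocation (a hedged sketch; the script filename and all paths are placeholders): |  | ||||||
| # |  | ||||||
| #   python convert_token_dropping_bert_checkpoint_to_pytorch.py \ |  | ||||||
| #       --tf_checkpoint_path ./token_dropping_checkpoint \ |  | ||||||
| #       --bert_config_file ./bert_config.json \ |  | ||||||
| #       --pytorch_dump_path ./converted-bert-mlm |  | ||||||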
| @ -1,69 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BigBird checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa): |  | ||||||
|     # Initialise PyTorch model |  | ||||||
|     config = BigBirdConfig.from_json_file(big_bird_config_file) |  | ||||||
|     print(f"Building PyTorch model from configuration: {config}") |  | ||||||
|  |  | ||||||
|     if is_trivia_qa: |  | ||||||
|         model = BigBirdForQuestionAnswering(config) |  | ||||||
|     else: |  | ||||||
|         model = BigBirdForPreTraining(config) |  | ||||||
|  |  | ||||||
|     # Load weights from tf checkpoint |  | ||||||
|     load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa) |  | ||||||
|  |  | ||||||
|     # Save pytorch-model |  | ||||||
|     print(f"Saving PyTorch model to {pytorch_dump_path}") |  | ||||||
|     model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--big_bird_config_file", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help=( |  | ||||||
|             "The config json file corresponding to the pre-trained BigBird model. \n" |  | ||||||
|             "This specifies the model architecture." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_tf_checkpoint_to_pytorch( |  | ||||||
|         args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa |  | ||||||
|     ) |  | ||||||
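|  |  | ||||||
| # Example invocation (a hedged sketch; the script filename and paths are placeholders; omit |  | ||||||
| # --is_trivia_qa to build a BigBirdForPreTraining model instead): |  | ||||||
| # |  | ||||||
| #   python convert_big_bird_tf_checkpoint_to_pytorch.py \ |  | ||||||
| #       --tf_checkpoint_path ./bigbird_trivia_qa_checkpoint \ |  | ||||||
| #       --big_bird_config_file ./big_bird_config.json \ |  | ||||||
| #       --pytorch_dump_path ./converted-bigbird \ |  | ||||||
| #       --is_trivia_qa |  | ||||||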
| @ -1,170 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| from typing import Dict |  | ||||||
|  |  | ||||||
| import tensorflow as tf |  | ||||||
| import torch |  | ||||||
| from tqdm import tqdm |  | ||||||
|  |  | ||||||
| from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration |  | ||||||
|  |  | ||||||
|  |  | ||||||
| INIT_COMMON = [ |  | ||||||
|     # tf -> hf |  | ||||||
|     ("/", "."), |  | ||||||
|     ("layer_", "layers."), |  | ||||||
|     ("kernel", "weight"), |  | ||||||
|     ("beta", "bias"), |  | ||||||
|     ("gamma", "weight"), |  | ||||||
|     ("pegasus", "model"), |  | ||||||
| ] |  | ||||||
| END_COMMON = [ |  | ||||||
|     (".output.dense", ".fc2"), |  | ||||||
|     ("intermediate.LayerNorm", "final_layer_norm"), |  | ||||||
|     ("intermediate.dense", "fc1"), |  | ||||||
| ] |  | ||||||
|  |  | ||||||
| DECODER_PATTERNS = ( |  | ||||||
|     INIT_COMMON |  | ||||||
|     + [ |  | ||||||
|         ("attention.self.LayerNorm", "self_attn_layer_norm"), |  | ||||||
|         ("attention.output.dense", "self_attn.out_proj"), |  | ||||||
|         ("attention.self", "self_attn"), |  | ||||||
|         ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"), |  | ||||||
|         ("attention.encdec_output.dense", "encoder_attn.out_proj"), |  | ||||||
|         ("attention.encdec", "encoder_attn"), |  | ||||||
|         ("key", "k_proj"), |  | ||||||
|         ("value", "v_proj"), |  | ||||||
|         ("query", "q_proj"), |  | ||||||
|         ("decoder.LayerNorm", "decoder.layernorm_embedding"), |  | ||||||
|     ] |  | ||||||
|     + END_COMMON |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| REMAINING_PATTERNS = ( |  | ||||||
|     INIT_COMMON |  | ||||||
|     + [ |  | ||||||
|         ("embeddings.word_embeddings", "shared.weight"), |  | ||||||
|         ("embeddings.position_embeddings", "embed_positions.weight"), |  | ||||||
|         ("attention.self.LayerNorm", "self_attn_layer_norm"), |  | ||||||
|         ("attention.output.dense", "self_attn.output"), |  | ||||||
|         ("attention.self", "self_attn.self"), |  | ||||||
|         ("encoder.LayerNorm", "encoder.layernorm_embedding"), |  | ||||||
|     ] |  | ||||||
|     + END_COMMON |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| KEYS_TO_IGNORE = [ |  | ||||||
|     "encdec/key/bias", |  | ||||||
|     "encdec/query/bias", |  | ||||||
|     "encdec/value/bias", |  | ||||||
|     "self/key/bias", |  | ||||||
|     "self/query/bias", |  | ||||||
|     "self/value/bias", |  | ||||||
|     "encdec_output/dense/bias", |  | ||||||
|     "attention/output/dense/bias", |  | ||||||
| ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_state_dict_key(k, patterns): |  | ||||||
|     for tf_name, hf_name in patterns: |  | ||||||
|         k = k.replace(tf_name, hf_name) |  | ||||||
|     return k |  | ||||||
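|  |  | ||||||
| # For illustration with a hypothetical TF key: under DECODER_PATTERNS, |  | ||||||
| # "pegasus/decoder/layer_0/attention/self/query/kernel" would be rewritten into the |  | ||||||
| # Hugging Face key "model.decoder.layers.0.self_attn.q_proj.weight". |  | ||||||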
|  |  | ||||||
|  |  | ||||||
| def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration: |  | ||||||
|     cfg = BigBirdPegasusConfig(**config_update) |  | ||||||
|     torch_model = BigBirdPegasusForConditionalGeneration(cfg) |  | ||||||
|     state_dict = torch_model.state_dict() |  | ||||||
|     mapping = {} |  | ||||||
|  |  | ||||||
|     # separating decoder weights |  | ||||||
|     decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")} |  | ||||||
|     remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")} |  | ||||||
|  |  | ||||||
|     for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"): |  | ||||||
|         conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] |  | ||||||
|         if any(conditions): |  | ||||||
|             continue |  | ||||||
|         patterns = DECODER_PATTERNS |  | ||||||
|         new_k = rename_state_dict_key(k, patterns) |  | ||||||
|         if new_k not in state_dict: |  | ||||||
|             raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") |  | ||||||
|         if any(i in k for i in ["dense", "query", "key", "value"]): |  | ||||||
|             v = v.T |  | ||||||
|         mapping[new_k] = torch.from_numpy(v) |  | ||||||
|         assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" |  | ||||||
|  |  | ||||||
|     for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"): |  | ||||||
|         conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] |  | ||||||
|         if any(conditions): |  | ||||||
|             continue |  | ||||||
|         patterns = REMAINING_PATTERNS |  | ||||||
|         new_k = rename_state_dict_key(k, patterns) |  | ||||||
|         if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings": |  | ||||||
|             raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") |  | ||||||
|         if any(i in k for i in ["dense", "query", "key", "value"]): |  | ||||||
|             v = v.T |  | ||||||
|         mapping[new_k] = torch.from_numpy(v) |  | ||||||
|         if k != "pegasus/embeddings/position_embeddings": |  | ||||||
|             assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" |  | ||||||
|  |  | ||||||
|     mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"] |  | ||||||
|     mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight") |  | ||||||
|     missing, extra = torch_model.load_state_dict(mapping, strict=False) |  | ||||||
|     unexpected_missing = [ |  | ||||||
|         k |  | ||||||
|         for k in missing |  | ||||||
|         if k |  | ||||||
|         not in [ |  | ||||||
|             "final_logits_bias", |  | ||||||
|             "model.encoder.embed_tokens.weight", |  | ||||||
|             "model.decoder.embed_tokens.weight", |  | ||||||
|             "lm_head.weight", |  | ||||||
|         ] |  | ||||||
|     ] |  | ||||||
|     assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" |  | ||||||
|     assert extra == [], f"no matches found for the following tf keys {extra}" |  | ||||||
|     return torch_model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_tf_weights_as_numpy(path) -> Dict: |  | ||||||
|     init_vars = tf.train.list_variables(path) |  | ||||||
|     tf_weights = {} |  | ||||||
|     ignore_name = ["global_step"] |  | ||||||
|     for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): |  | ||||||
|         skip_key = any(pat in name for pat in ignore_name) |  | ||||||
|         if skip_key: |  | ||||||
|             continue |  | ||||||
|         array = tf.train.load_variable(path, name) |  | ||||||
|         tf_weights[name] = array |  | ||||||
|     return tf_weights |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict): |  | ||||||
|     tf_weights = get_tf_weights_as_numpy(ckpt_path) |  | ||||||
|     torch_model = convert_bigbird_pegasus(tf_weights, config_update) |  | ||||||
|     torch_model.save_pretrained(save_dir) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables") |  | ||||||
|     parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     config_update = {} |  | ||||||
|     convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update) |  | ||||||
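|  |  | ||||||
| # Example invocation (a hedged sketch; the script filename and paths are placeholders, and |  | ||||||
| # config_update above is left empty so the default BigBirdPegasusConfig values are used): |  | ||||||
| # |  | ||||||
| #   python convert_bigbird_pegasus_tf_to_pytorch.py \ |  | ||||||
| #       --tf_ckpt_path ./bigbird_pegasus_tf_checkpoint \ |  | ||||||
| #       --save_dir ./bigbird-pegasus-converted |  | ||||||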
| @ -1,292 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| import os |  | ||||||
| import re |  | ||||||
| import shutil |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BioGptConfig, BioGptForCausalLM |  | ||||||
| from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES |  | ||||||
| from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE |  | ||||||
| from transformers.utils import WEIGHTS_NAME, logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_warning() |  | ||||||
|  |  | ||||||
| json_indent = 2 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18 |  | ||||||
| class Dictionary: |  | ||||||
|     """A mapping from symbols to consecutive integers""" |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         *,  # begin keyword-only arguments |  | ||||||
|         bos="<s>", |  | ||||||
|         pad="<pad>", |  | ||||||
|         eos="</s>", |  | ||||||
|         unk="<unk>", |  | ||||||
|         extra_special_symbols=None, |  | ||||||
|     ): |  | ||||||
|         self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos |  | ||||||
|         self.symbols = [] |  | ||||||
|         self.count = [] |  | ||||||
|         self.indices = {} |  | ||||||
|         self.bos_index = self.add_symbol(bos) |  | ||||||
|         self.pad_index = self.add_symbol(pad) |  | ||||||
|         self.eos_index = self.add_symbol(eos) |  | ||||||
|         self.unk_index = self.add_symbol(unk) |  | ||||||
|         if extra_special_symbols: |  | ||||||
|             for s in extra_special_symbols: |  | ||||||
|                 self.add_symbol(s) |  | ||||||
|         self.nspecial = len(self.symbols) |  | ||||||
|  |  | ||||||
|     def __eq__(self, other): |  | ||||||
|         return self.indices == other.indices |  | ||||||
|  |  | ||||||
|     def __getitem__(self, idx): |  | ||||||
|         if idx < len(self.symbols): |  | ||||||
|             return self.symbols[idx] |  | ||||||
|         return self.unk_word |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         """Returns the number of symbols in the dictionary""" |  | ||||||
|         return len(self.symbols) |  | ||||||
|  |  | ||||||
|     def __contains__(self, sym): |  | ||||||
|         return sym in self.indices |  | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def load(cls, f): |  | ||||||
|         """Loads the dictionary from a text file with the format: |  | ||||||
|  |  | ||||||
|         ``` |  | ||||||
|         <symbol0> <count0> |  | ||||||
|         <symbol1> <count1> |  | ||||||
|         ... |  | ||||||
|         ``` |  | ||||||
|         """ |  | ||||||
|         d = cls() |  | ||||||
|         d.add_from_file(f) |  | ||||||
|         return d |  | ||||||
|  |  | ||||||
|     def add_symbol(self, word, n=1, overwrite=False): |  | ||||||
|         """Adds a word to the dictionary""" |  | ||||||
|         if word in self.indices and not overwrite: |  | ||||||
|             idx = self.indices[word] |  | ||||||
|             self.count[idx] = self.count[idx] + n |  | ||||||
|             return idx |  | ||||||
|         else: |  | ||||||
|             idx = len(self.symbols) |  | ||||||
|             self.indices[word] = idx |  | ||||||
|             self.symbols.append(word) |  | ||||||
|             self.count.append(n) |  | ||||||
|             return idx |  | ||||||
|  |  | ||||||
|     def _load_meta(self, lines): |  | ||||||
|         return 0 |  | ||||||
|  |  | ||||||
|     def add_from_file(self, f): |  | ||||||
|         """ |  | ||||||
|         Loads a pre-existing dictionary from a text file and adds its symbols to this instance. |  | ||||||
|         """ |  | ||||||
|         if isinstance(f, str): |  | ||||||
|             try: |  | ||||||
|                 with open(f, "r", encoding="utf-8") as fd: |  | ||||||
|                     self.add_from_file(fd) |  | ||||||
|             except FileNotFoundError as fnfe: |  | ||||||
|                 raise fnfe |  | ||||||
|             except UnicodeError: |  | ||||||
|                 raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f)) |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         lines = f.readlines() |  | ||||||
|         indices_start_line = self._load_meta(lines) |  | ||||||
|  |  | ||||||
|         for line in lines[indices_start_line:]: |  | ||||||
|             try: |  | ||||||
|                 line, field = line.rstrip().rsplit(" ", 1) |  | ||||||
|                 if field == "#fairseq:overwrite": |  | ||||||
|                     overwrite = True |  | ||||||
|                     line, field = line.rsplit(" ", 1) |  | ||||||
|                 else: |  | ||||||
|                     overwrite = False |  | ||||||
|                 count = int(field) |  | ||||||
|                 word = line |  | ||||||
|                 if word in self and not overwrite: |  | ||||||
|                     raise RuntimeError( |  | ||||||
|                         "Duplicate word found when loading Dictionary: '{}'. " |  | ||||||
|                         "Duplicate words can overwrite earlier ones by adding the " |  | ||||||
|                         "#fairseq:overwrite flag at the end of the corresponding row " |  | ||||||
|                         "in the dictionary file. If using the Camembert model, please " |  | ||||||
|                         "download an updated copy of the model file.".format(word) |  | ||||||
|                     ) |  | ||||||
|                 self.add_symbol(word, n=count, overwrite=overwrite) |  | ||||||
|             except ValueError: |  | ||||||
|                 raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rewrite_dict_keys(d): |  | ||||||
|     # (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up, |  | ||||||
|     # e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7} |  | ||||||
|     d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items()) |  | ||||||
|     keep_keys = "<s> <pad> </s> <unk>".split() |  | ||||||
|     # restore the special tokens |  | ||||||
|     for k in keep_keys: |  | ||||||
|         del d2[f"{k}</w>"] |  | ||||||
|         d2[k] = d[k]  # restore |  | ||||||
|     return d2 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path): |  | ||||||
|     # prep |  | ||||||
|     if not os.path.exists(biogpt_checkpoint_path): |  | ||||||
|         raise ValueError(f"path {biogpt_checkpoint_path} does not exist!") |  | ||||||
|     os.makedirs(pytorch_dump_folder_path, exist_ok=True) |  | ||||||
|     print(f"Writing results to {pytorch_dump_folder_path}") |  | ||||||
|  |  | ||||||
|     # handle various types of models |  | ||||||
|  |  | ||||||
|     checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt") |  | ||||||
|     if not os.path.isfile(checkpoint_file): |  | ||||||
|         raise ValueError(f"path to the file {checkpoint_file} does not exist!") |  | ||||||
|     chkpt = torch.load(checkpoint_file, map_location="cpu") |  | ||||||
|  |  | ||||||
|     args = chkpt["cfg"]["model"] |  | ||||||
|  |  | ||||||
|     # dicts |  | ||||||
|     dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt") |  | ||||||
|     if not os.path.isfile(dict_file): |  | ||||||
|         raise ValueError(f"path to the file {dict_file} does not exist!") |  | ||||||
|     src_dict = Dictionary.load(dict_file) |  | ||||||
|     src_vocab = rewrite_dict_keys(src_dict.indices) |  | ||||||
|     src_vocab_size = len(src_vocab) |  | ||||||
|     src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"]) |  | ||||||
|     print(f"Generating {src_vocab_file} of {src_vocab_size} records") |  | ||||||
|     with open(src_vocab_file, "w", encoding="utf-8") as f: |  | ||||||
|         f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent)) |  | ||||||
|  |  | ||||||
|     # merges_file (bpecodes) |  | ||||||
|     bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes") |  | ||||||
|     if not os.path.isfile(bpecodes_file): |  | ||||||
|         raise ValueError(f"path to the file {bpecodes_file} does not exist!") |  | ||||||
|  |  | ||||||
|     merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"]) |  | ||||||
|     shutil.copyfile(bpecodes_file, merges_file) |  | ||||||
|  |  | ||||||
|     # model config |  | ||||||
|     biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json") |  | ||||||
|  |  | ||||||
|     model_conf = { |  | ||||||
|         "activation_dropout": args["activation_dropout"], |  | ||||||
|         "architectures": ["BioGptForCausalLM"], |  | ||||||
|         "attention_probs_dropout_prob": args["attention_dropout"], |  | ||||||
|         "bos_token_id": 0, |  | ||||||
|         "eos_token_id": 2, |  | ||||||
|         "hidden_act": args["activation_fn"], |  | ||||||
|         "hidden_dropout_prob": args["dropout"], |  | ||||||
|         "hidden_size": args["decoder_embed_dim"], |  | ||||||
|         "initializer_range": 0.02, |  | ||||||
|         "intermediate_size": args["decoder_ffn_embed_dim"], |  | ||||||
|         "layer_norm_eps": 1e-12, |  | ||||||
|         "layerdrop": args["decoder_layerdrop"], |  | ||||||
|         "max_position_embeddings": args["max_target_positions"], |  | ||||||
|         "model_type": "biogpt", |  | ||||||
|         "num_attention_heads": args["decoder_attention_heads"], |  | ||||||
|         "num_hidden_layers": args["decoder_layers"], |  | ||||||
|         "pad_token_id": 1, |  | ||||||
|         "scale_embedding": not args["no_scale_embedding"], |  | ||||||
|         "tie_word_embeddings": args["share_decoder_input_output_embed"], |  | ||||||
|         "vocab_size": src_vocab_size, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     # good hparam defaults to start with |  | ||||||
|  |  | ||||||
|     print(f"Generating {biogpt_model_config_file}") |  | ||||||
|     with open(biogpt_model_config_file, "w", encoding="utf-8") as f: |  | ||||||
|         f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent)) |  | ||||||
|  |  | ||||||
|     # tokenizer config |  | ||||||
|     biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE) |  | ||||||
|  |  | ||||||
|     tokenizer_conf = { |  | ||||||
|         "bos_token": "<s>", |  | ||||||
|         "eos_token": "</s>", |  | ||||||
|         "model_max_length": 1024, |  | ||||||
|         "pad_token": "<pad>", |  | ||||||
|         "special_tokens_map_file": None, |  | ||||||
|         "tokenizer_class": "BioGptTokenizer", |  | ||||||
|         "unk_token": "<unk>", |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     print(f"Generating {biogpt_tokenizer_config_file}") |  | ||||||
|     with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f: |  | ||||||
|         f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent)) |  | ||||||
|  |  | ||||||
|     # model |  | ||||||
|     model_state_dict = chkpt["model"] |  | ||||||
|  |  | ||||||
|     # remove unneeded keys |  | ||||||
|     ignore_keys = [ |  | ||||||
|         "decoder.version", |  | ||||||
|     ] |  | ||||||
|     for k in ignore_keys: |  | ||||||
|         model_state_dict.pop(k, None) |  | ||||||
|  |  | ||||||
|     layer_names = list(model_state_dict.keys()) |  | ||||||
|     for layer_name in layer_names: |  | ||||||
|         if layer_name.endswith("output_projection.weight"): |  | ||||||
|             model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name) |  | ||||||
|         else: |  | ||||||
|             model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name) |  | ||||||
|  |  | ||||||
|     config = BioGptConfig.from_pretrained(pytorch_dump_folder_path) |  | ||||||
|     model_new = BioGptForCausalLM(config) |  | ||||||
|  |  | ||||||
|     # check that it loads ok |  | ||||||
|     model_new.load_state_dict(model_state_dict) |  | ||||||
|  |  | ||||||
|     # save |  | ||||||
|     pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) |  | ||||||
|     print(f"Generating {pytorch_weights_dump_path}") |  | ||||||
|     torch.save(model_state_dict, pytorch_weights_dump_path) |  | ||||||
|  |  | ||||||
|     print("Conversion is done!") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--biogpt_checkpoint_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help=( |  | ||||||
|             "Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts," |  | ||||||
|             " bpecodes, etc." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path) |  | ||||||
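|  |  | ||||||
| # Example invocation (a hedged sketch; the script filename is a placeholder and the checkpoint |  | ||||||
| # directory is expected to contain checkpoint.pt, dict.txt and bpecodes, as checked above): |  | ||||||
| # |  | ||||||
| #   python convert_biogpt_original_checkpoint_to_pytorch.py \ |  | ||||||
| #       --biogpt_checkpoint_path ./biogpt_fairseq_checkpoint \ |  | ||||||
| #       --pytorch_dump_folder_path ./biogpt-converted |  | ||||||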
| @ -1,177 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BiT checkpoints from the timm library.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
| from PIL import Image |  | ||||||
| from timm import create_model |  | ||||||
| from timm.data import resolve_data_config |  | ||||||
| from timm.data.transforms_factory import create_transform |  | ||||||
|  |  | ||||||
| from transformers import BitConfig, BitForImageClassification, BitImageProcessor |  | ||||||
| from transformers.image_utils import PILImageResampling |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_config(model_name): |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     filename = "imagenet-1k-id2label.json" |  | ||||||
|     with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f: |  | ||||||
|         id2label = json.load(f) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|     label2id = {v: k for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     conv_layer = "std_conv" if "bit" in model_name else False |  | ||||||
|  |  | ||||||
|     # note that when using BiT as backbone for ViT-hybrid checkpoints, |  | ||||||
|     # one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same", |  | ||||||
|     # config.conv_layer = "std_conv_same" |  | ||||||
|     config = BitConfig( |  | ||||||
|         conv_layer=conv_layer, |  | ||||||
|         num_labels=1000, |  | ||||||
|         id2label=id2label, |  | ||||||
|         label2id=label2id, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     return config |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     if "stem.conv" in name: |  | ||||||
|         name = name.replace("stem.conv", "bit.embedder.convolution") |  | ||||||
|     if "blocks" in name: |  | ||||||
|         name = name.replace("blocks", "layers") |  | ||||||
|     if "head.fc" in name: |  | ||||||
|         name = name.replace("head.fc", "classifier.1") |  | ||||||
|     if name.startswith("norm"): |  | ||||||
|         name = "bit." + name |  | ||||||
|     if "bit" not in name and "classifier" not in name: |  | ||||||
|         name = "bit.encoder." + name |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     im = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|     return im |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our BiT structure. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # define default BiT configuration |  | ||||||
|     config = get_config(model_name) |  | ||||||
|  |  | ||||||
|     # load original model from timm |  | ||||||
|     timm_model = create_model(model_name, pretrained=True) |  | ||||||
|     timm_model.eval() |  | ||||||
|  |  | ||||||
|     # load state_dict of original model |  | ||||||
|     state_dict = timm_model.state_dict() |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         state_dict[rename_key(key)] = val.squeeze() if "head" in key else val |  | ||||||
|  |  | ||||||
|     # load HuggingFace model |  | ||||||
|     model = BitForImageClassification(config) |  | ||||||
|     model.eval() |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|  |  | ||||||
|     # create image processor |  | ||||||
|     transform = create_transform(**resolve_data_config({}, model=timm_model)) |  | ||||||
|     timm_transforms = transform.transforms |  | ||||||
|  |  | ||||||
|     pillow_resamplings = { |  | ||||||
|         "bilinear": PILImageResampling.BILINEAR, |  | ||||||
|         "bicubic": PILImageResampling.BICUBIC, |  | ||||||
|         "nearest": PILImageResampling.NEAREST, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     processor = BitImageProcessor( |  | ||||||
|         do_resize=True, |  | ||||||
|         size={"shortest_edge": timm_transforms[0].size}, |  | ||||||
|         resample=pillow_resamplings[timm_transforms[0].interpolation.value], |  | ||||||
|         do_center_crop=True, |  | ||||||
|         crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]}, |  | ||||||
|         do_normalize=True, |  | ||||||
|         image_mean=timm_transforms[-1].mean.tolist(), |  | ||||||
|         image_std=timm_transforms[-1].std.tolist(), |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     image = prepare_img() |  | ||||||
|     timm_pixel_values = transform(image).unsqueeze(0) |  | ||||||
|     pixel_values = processor(image, return_tensors="pt").pixel_values |  | ||||||
|  |  | ||||||
|     # verify pixel values |  | ||||||
|     assert torch.allclose(timm_pixel_values, pixel_values) |  | ||||||
|  |  | ||||||
|     # verify logits |  | ||||||
|     with torch.no_grad(): |  | ||||||
|         outputs = model(pixel_values) |  | ||||||
|         logits = outputs.logits |  | ||||||
|  |  | ||||||
|     print("Logits:", logits[0, :3]) |  | ||||||
|     print("Predicted class:", model.config.id2label[logits.argmax(-1).item()]) |  | ||||||
|     timm_logits = timm_model(pixel_values) |  | ||||||
|     assert timm_logits.shape == outputs.logits.shape |  | ||||||
|     assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) |  | ||||||
|     print("Looks ok!") |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|         print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}") |  | ||||||
|         model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         print(f"Pushing model {model_name} and processor to the hub") |  | ||||||
|         model.push_to_hub(f"ybelkada/{model_name}") |  | ||||||
|         processor.push_to_hub(f"ybelkada/{model_name}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="resnetv2_50x1_bitm", |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the BiT timm model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", |  | ||||||
|         action="store_true", |  | ||||||
|         help="Whether to push the model to the hub.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) |  | ||||||
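For reference, a checkpoint produced by a conversion script like the one above can be loaded through the regular Transformers API. The snippet below is only a minimal sketch; `google/bit-50` is used purely as an example of an already-converted BiT checkpoint on the Hub.

```py
import requests
import torch
from PIL import Image

from transformers import AutoImageProcessor, BitForImageClassification

# "google/bit-50" is only an example of a converted BiT checkpoint
processor = AutoImageProcessor.from_pretrained("google/bit-50")
model = BitForImageClassification.from_pretrained("google/bit-50")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
```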
| @ -1,114 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2020 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert Blenderbot checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| PATTERNS = [ |  | ||||||
|     ["attention", "attn"], |  | ||||||
|     ["encoder_attention", "encoder_attn"], |  | ||||||
|     ["q_lin", "q_proj"], |  | ||||||
|     ["k_lin", "k_proj"], |  | ||||||
|     ["v_lin", "v_proj"], |  | ||||||
|     ["out_lin", "out_proj"], |  | ||||||
|     ["norm_embeddings", "layernorm_embedding"], |  | ||||||
|     ["position_embeddings", "embed_positions"], |  | ||||||
|     ["embeddings", "embed_tokens"], |  | ||||||
|     ["ffn.lin", "fc"], |  | ||||||
| ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_state_dict_key(k): |  | ||||||
|     if k == "embeddings.weight": |  | ||||||
|         return "shared.weight" |  | ||||||
|  |  | ||||||
|     for parlai_name, hf_name in PATTERNS: |  | ||||||
|         k = k.replace(parlai_name, hf_name) |  | ||||||
|  |  | ||||||
|     if k.startswith("encoder"): |  | ||||||
|         k = k.replace(".attn", ".self_attn") |  | ||||||
|         k = k.replace("norm1", "self_attn_layer_norm") |  | ||||||
|         k = k.replace("norm2", "final_layer_norm") |  | ||||||
|     elif k.startswith("decoder"): |  | ||||||
|         k = k.replace("norm1", "self_attn_layer_norm") |  | ||||||
|         k = k.replace("norm2", "encoder_attn_layer_norm") |  | ||||||
|         k = k.replace("norm3", "final_layer_norm") |  | ||||||
|     return k |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_layernorm_keys(sd): |  | ||||||
|     keys = [ |  | ||||||
|         "model.encoder.layernorm_embedding.weight", |  | ||||||
|         "model.encoder.layernorm_embedding.bias", |  | ||||||
|         "model.decoder.layernorm_embedding.weight", |  | ||||||
|         "model.decoder.layernorm_embedding.bias", |  | ||||||
|     ] |  | ||||||
|     for k in keys: |  | ||||||
|         v = sd.pop(k) |  | ||||||
|         new_k = k.replace("layernorm_embedding", "layer_norm") |  | ||||||
|         assert new_k not in sd |  | ||||||
|         sd[new_k] = v |  | ||||||
|  |  | ||||||
|  |  | ||||||
| IGNORE_KEYS = ["START"] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our Blenderbot structure. |  | ||||||
|     """ |  | ||||||
|     model = torch.load(checkpoint_path, map_location="cpu") |  | ||||||
|     sd = model["model"] |  | ||||||
|     cfg = BlenderbotConfig.from_json_file(config_json_path) |  | ||||||
|     m = BlenderbotForConditionalGeneration(cfg) |  | ||||||
|     valid_keys = m.model.state_dict().keys() |  | ||||||
|     failures = [] |  | ||||||
|     mapping = {} |  | ||||||
|     for k, v in sd.items(): |  | ||||||
|         if k in IGNORE_KEYS: |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         new_k = rename_state_dict_key(k) |  | ||||||
|         if new_k not in valid_keys: |  | ||||||
|             failures.append([k, new_k]) |  | ||||||
|         else: |  | ||||||
|             mapping[new_k] = v |  | ||||||
|     if cfg.normalize_before:  # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm |  | ||||||
|         rename_layernorm_keys(sd) |  | ||||||
|     m.model.load_state_dict(mapping, strict=True) |  | ||||||
|     m.half() |  | ||||||
|     m.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin") |  | ||||||
|     parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use" |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json) |  | ||||||
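Once converted, a Blenderbot checkpoint behaves like any other seq2seq model in Transformers. A minimal generation sketch is shown below; `facebook/blenderbot-400M-distill` is used only as an example of a published Blenderbot checkpoint.

```py
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizer

# "facebook/blenderbot-400M-distill" is only an example of a converted checkpoint
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
model = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")

inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="pt")
reply_ids = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True))
```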
| @ -1,191 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| # git clone https://github.com/salesforce/BLIP.git |  | ||||||
| from models.blip import blip_decoder |  | ||||||
| from models.blip_itm import blip_itm |  | ||||||
| from models.blip_vqa import blip_vqa |  | ||||||
| from PIL import Image |  | ||||||
| from torchvision import transforms |  | ||||||
| from torchvision.transforms.functional import InterpolationMode |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     BertTokenizer, |  | ||||||
|     BlipConfig, |  | ||||||
|     BlipForConditionalGeneration, |  | ||||||
|     BlipForImageTextRetrieval, |  | ||||||
|     BlipForQuestionAnswering, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_demo_image(image_size, device): |  | ||||||
|     img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" |  | ||||||
|     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") |  | ||||||
|  |  | ||||||
|     transform = transforms.Compose( |  | ||||||
|         [ |  | ||||||
|             transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), |  | ||||||
|             transforms.ToTensor(), |  | ||||||
|             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|     image = transform(raw_image).unsqueeze(0).to(device) |  | ||||||
|     return image |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(key): |  | ||||||
|     if "visual_encoder" in key: |  | ||||||
|         key = re.sub("visual_encoder*", "vision_model.encoder", key) |  | ||||||
|     if "blocks" in key: |  | ||||||
|         key = re.sub(r"blocks", "layers", key) |  | ||||||
|     if "attn" in key: |  | ||||||
|         key = re.sub(r"attn", "self_attn", key) |  | ||||||
|     if "norm1" in key: |  | ||||||
|         key = re.sub(r"norm1", "layer_norm1", key) |  | ||||||
|     if "norm2" in key: |  | ||||||
|         key = re.sub(r"norm2", "layer_norm2", key) |  | ||||||
|     if "encoder.norm" in key: |  | ||||||
|         key = re.sub(r"encoder.norm", "post_layernorm", key) |  | ||||||
|     if "encoder.patch_embed.proj" in key: |  | ||||||
|         key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key) |  | ||||||
|  |  | ||||||
|     if "encoder.pos_embed" in key: |  | ||||||
|         key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key) |  | ||||||
|     if "encoder.cls_token" in key: |  | ||||||
|         key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key) |  | ||||||
|  |  | ||||||
|     if "self_attn" in key: |  | ||||||
|         key = re.sub(r"self_attn.proj", "self_attn.projection", key) |  | ||||||
|  |  | ||||||
|     return key |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to transformers design. |  | ||||||
|     """ |  | ||||||
|     if config_path is not None: |  | ||||||
|         config = BlipConfig.from_pretrained(config_path) |  | ||||||
|     else: |  | ||||||
|         config = BlipConfig(projection_dim=512, text_config={}, vision_config={}) |  | ||||||
|  |  | ||||||
|     hf_model = BlipForConditionalGeneration(config).eval() |  | ||||||
|  |  | ||||||
|     model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" |  | ||||||
|  |  | ||||||
|     pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base") |  | ||||||
|     pt_model = pt_model.eval() |  | ||||||
|  |  | ||||||
|     modified_state_dict = pt_model.state_dict() |  | ||||||
|     for key in modified_state_dict.copy(): |  | ||||||
|         value = modified_state_dict.pop(key) |  | ||||||
|         renamed_key = rename_key(key) |  | ||||||
|         modified_state_dict[renamed_key] = value |  | ||||||
|  |  | ||||||
|     hf_model.load_state_dict(modified_state_dict) |  | ||||||
|  |  | ||||||
|     image_size = 384 |  | ||||||
|     image = load_demo_image(image_size=image_size, device="cpu") |  | ||||||
|     tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") |  | ||||||
|     input_ids = tokenizer(["a picture of"]).input_ids |  | ||||||
|  |  | ||||||
|     out = hf_model.generate(image, input_ids) |  | ||||||
|  |  | ||||||
|     assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102] |  | ||||||
|  |  | ||||||
|     out = hf_model.generate(image) |  | ||||||
|  |  | ||||||
|     assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102] |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         hf_model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth' |  | ||||||
|     model_url = ( |  | ||||||
|         "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base") |  | ||||||
|     vqa_model.eval() |  | ||||||
|  |  | ||||||
|     modified_state_dict = vqa_model.state_dict() |  | ||||||
|     for key in modified_state_dict.copy(): |  | ||||||
|         value = modified_state_dict.pop(key) |  | ||||||
|         renamed_key = rename_key(key) |  | ||||||
|         modified_state_dict[renamed_key] = value |  | ||||||
|  |  | ||||||
|     hf_vqa_model = BlipForQuestionAnswering(config) |  | ||||||
|  |  | ||||||
|     hf_vqa_model.load_state_dict(modified_state_dict) |  | ||||||
|  |  | ||||||
|     question = ["How many dogs are in this image?"] |  | ||||||
|     question_input_ids = tokenizer(question, return_tensors="pt").input_ids |  | ||||||
|  |  | ||||||
|     answer = hf_vqa_model.generate(question_input_ids, image) |  | ||||||
|     print(tokenizer.decode(answer[0])) |  | ||||||
|  |  | ||||||
|     assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]" |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa") |  | ||||||
|  |  | ||||||
|     model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" |  | ||||||
|  |  | ||||||
|     itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base") |  | ||||||
|     itm_model.eval() |  | ||||||
|  |  | ||||||
|     modified_state_dict = itm_model.state_dict() |  | ||||||
|     for key in modified_state_dict.copy(): |  | ||||||
|         value = modified_state_dict.pop(key) |  | ||||||
|         renamed_key = rename_key(key) |  | ||||||
|         modified_state_dict[renamed_key] = value |  | ||||||
|  |  | ||||||
|     hf_itm_model = BlipForImageTextRetrieval(config) |  | ||||||
|  |  | ||||||
|     question = ["A picture of a woman with a dog sitting in a beach"] |  | ||||||
|     question_input_ids = tokenizer( |  | ||||||
|         question, |  | ||||||
|         return_tensors="pt", |  | ||||||
|         padding="max_length", |  | ||||||
|         truncation=True, |  | ||||||
|         max_length=35, |  | ||||||
|     ).input_ids |  | ||||||
|  |  | ||||||
|     hf_itm_model.load_state_dict(modified_state_dict) |  | ||||||
|     hf_itm_model.eval() |  | ||||||
|  |  | ||||||
|     out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True) |  | ||||||
|     out = hf_itm_model(question_input_ids, image, use_itm_head=False) |  | ||||||
|  |  | ||||||
|     assert out[0].item() == 0.2110687494277954 |  | ||||||
|     assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127 |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path) |  | ||||||
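A converted BLIP captioning checkpoint can be exercised with the high-level processor and model classes. This is a minimal sketch; `Salesforce/blip-image-captioning-base` is used only as an example of a checkpoint produced by this kind of conversion.

```py
import requests
from PIL import Image

from transformers import BlipForConditionalGeneration, BlipProcessor

# "Salesforce/blip-image-captioning-base" is only an example of a converted checkpoint
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# conditional captioning: the text prompt is optional
inputs = processor(images=image, text="a picture of", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(out[0], skip_special_tokens=True))
```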
| @ -1,390 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """ |  | ||||||
| Convert BLIP-2 checkpoints from the original repository. |  | ||||||
|  |  | ||||||
| URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2 |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| # pip3 install salesforce-lavis |  | ||||||
| # I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32 |  | ||||||
| # to make sure we can compare both original and HF implementation in float32 |  | ||||||
| from lavis.models import load_model_and_preprocess |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     AutoTokenizer, |  | ||||||
|     BertTokenizer, |  | ||||||
|     Blip2Config, |  | ||||||
|     Blip2ForConditionalGeneration, |  | ||||||
|     Blip2ForImageTextRetrieval, |  | ||||||
|     Blip2Processor, |  | ||||||
|     Blip2QFormerConfig, |  | ||||||
|     Blip2VisionConfig, |  | ||||||
|     BlipImageProcessor, |  | ||||||
|     OPTConfig, |  | ||||||
|     T5Config, |  | ||||||
|     set_seed, |  | ||||||
| ) |  | ||||||
| from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_demo_image(): |  | ||||||
|     url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png" |  | ||||||
|     image = Image.open(requests.get(url, stream=True).raw).convert("RGB") |  | ||||||
|  |  | ||||||
|     return image |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # here we list all keys to be renamed (original name on the left, our name on the right) |  | ||||||
| def create_rename_keys(config, model_name): |  | ||||||
|     rename_keys = [] |  | ||||||
|     # fmt: off |  | ||||||
|  |  | ||||||
|     # vision encoder |  | ||||||
|     rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding")) |  | ||||||
|     rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding")) |  | ||||||
|     rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight")) |  | ||||||
|     rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias")) |  | ||||||
|     rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight")) |  | ||||||
|     rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias")) |  | ||||||
|  |  | ||||||
|     for i in range(config.vision_config.num_hidden_layers): |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",)) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight")) |  | ||||||
|         rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias")) |  | ||||||
|  |  | ||||||
|     # QFormer |  | ||||||
|     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight")) |  | ||||||
|     rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias")) |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight")) |  | ||||||
|         rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight")) |  | ||||||
|         rename_keys.append(("vision_proj.weight", "vision_projection.weight")) |  | ||||||
|         rename_keys.append(("vision_proj.bias", "vision_projection.bias")) |  | ||||||
|         rename_keys.append(("text_proj.weight", "text_projection.weight")) |  | ||||||
|         rename_keys.append(("text_proj.bias", "text_projection.bias")) |  | ||||||
|  |  | ||||||
|     # fmt: on |  | ||||||
|     return rename_keys |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(dct, old, new): |  | ||||||
|     val = dct.pop(old) |  | ||||||
|     dct[new] = val |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_in_q_v_bias(state_dict, config): |  | ||||||
|     for i in range(config.vision_config.num_hidden_layers): |  | ||||||
|         # read in original q and v biases |  | ||||||
|         q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias") |  | ||||||
|         v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias") |  | ||||||
|  |  | ||||||
|         # next, set bias in the state dict |  | ||||||
|         qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias)) |  | ||||||
|         state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_blip2_config(model_name, eos_token_id): |  | ||||||
|     image_size = 364 if "coco" in model_name else 224 |  | ||||||
|     vision_config = Blip2VisionConfig(image_size=image_size).to_dict() |  | ||||||
|  |  | ||||||
|     # make sure the models have proper bos_token_id and eos_token_id set (important for generation) |  | ||||||
|     # seems like flan-T5 models don't have bos_token_id properly set? |  | ||||||
|     if "opt-2.7b" in model_name: |  | ||||||
|         text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict() |  | ||||||
|     elif "opt-6.7b" in model_name: |  | ||||||
|         text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict() |  | ||||||
|     elif "t5-xl" in model_name: |  | ||||||
|         text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict() |  | ||||||
|     elif "t5-xxl" in model_name: |  | ||||||
|         text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict() |  | ||||||
|     elif "itm" in model_name: |  | ||||||
|         text_config = {} |  | ||||||
|     else: |  | ||||||
|         raise ValueError("Model name not supported") |  | ||||||
|  |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         config = Blip2Config( |  | ||||||
|             vision_config=vision_config, |  | ||||||
|             qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(), |  | ||||||
|         ) |  | ||||||
|     else: |  | ||||||
|         config = Blip2Config(vision_config=vision_config, text_config=text_config) |  | ||||||
|  |  | ||||||
|     return config, image_size |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_blip2_checkpoint( |  | ||||||
|     model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu" |  | ||||||
| ): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to Transformers design. |  | ||||||
|     """ |  | ||||||
|     if "opt" in model_name: |  | ||||||
|         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b") |  | ||||||
|     elif "itm" in model_name: |  | ||||||
|         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") |  | ||||||
|         tokenizer.add_special_tokens({"bos_token": "[DEC]"}) |  | ||||||
|     else: |  | ||||||
|         tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") |  | ||||||
|  |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         eos_token_id = None |  | ||||||
|     else: |  | ||||||
|         eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0] |  | ||||||
|     config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id) |  | ||||||
|  |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         hf_model = Blip2ForImageTextRetrieval(config).eval() |  | ||||||
|     else: |  | ||||||
|         hf_model = Blip2ForConditionalGeneration(config).eval() |  | ||||||
|  |  | ||||||
|     model_name_to_original = { |  | ||||||
|         "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"), |  | ||||||
|         "blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"), |  | ||||||
|         "blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"), |  | ||||||
|         "blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"), |  | ||||||
|         "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"), |  | ||||||
|         "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"), |  | ||||||
|         "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"), |  | ||||||
|         "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"), |  | ||||||
|         "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"), |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     name, type = model_name_to_original[model_name] |  | ||||||
|  |  | ||||||
|     # load original model |  | ||||||
|     print("Loading original model...") |  | ||||||
|     original_model, vis_processors, _ = load_model_and_preprocess( |  | ||||||
|         name=name, model_type=type, is_eval=True, device=lavis_device |  | ||||||
|     ) |  | ||||||
|     original_model.eval() |  | ||||||
|     print("Done!") |  | ||||||
|  |  | ||||||
|     # update state dict keys |  | ||||||
|     state_dict = original_model.state_dict() |  | ||||||
|     rename_keys = create_rename_keys(config, model_name) |  | ||||||
|     for src, dest in rename_keys: |  | ||||||
|         rename_key(state_dict, src, dest) |  | ||||||
|  |  | ||||||
|     # the remaining keys can be renamed with simple string replacements |  | ||||||
|     for key, val in state_dict.copy().items(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         if key.startswith("Qformer.bert"): |  | ||||||
|             key = key.replace("Qformer.bert", "qformer") |  | ||||||
|         if "attention.self" in key: |  | ||||||
|             key = key.replace("self", "attention") |  | ||||||
|         if "opt_proj" in key: |  | ||||||
|             key = key.replace("opt_proj", "language_projection") |  | ||||||
|         if "t5_proj" in key: |  | ||||||
|             key = key.replace("t5_proj", "language_projection") |  | ||||||
|         if key.startswith("opt"): |  | ||||||
|             key = key.replace("opt", "language") |  | ||||||
|         if key.startswith("t5"): |  | ||||||
|             key = key.replace("t5", "language") |  | ||||||
|         state_dict[key] = val |  | ||||||
|  |  | ||||||
|     # read in qv biases |  | ||||||
|     read_in_q_v_bias(state_dict, config) |  | ||||||
|  |  | ||||||
|     missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False) |  | ||||||
|     assert len(missing_keys) == 0 |  | ||||||
|  |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys)) |  | ||||||
|         assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"] |  | ||||||
|     else: |  | ||||||
|         assert unexpected_keys == ["qformer.embeddings.position_ids"] |  | ||||||
|  |  | ||||||
|     image = load_demo_image() |  | ||||||
|     original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device) |  | ||||||
|  |  | ||||||
|     # create processor |  | ||||||
|     image_processor = BlipImageProcessor( |  | ||||||
|         size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD |  | ||||||
|     ) |  | ||||||
|     processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer) |  | ||||||
|     pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device) |  | ||||||
|  |  | ||||||
|     # make sure processor creates exact same pixel values |  | ||||||
|     assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device)) |  | ||||||
|  |  | ||||||
|     original_model.to(lavis_device) |  | ||||||
|     hf_model.to(hf_model_device) |  | ||||||
|  |  | ||||||
|     if "itm" in model_name: |  | ||||||
|         caption = "a large fountain spewing water into the air" |  | ||||||
|         input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device) |  | ||||||
|         attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device) |  | ||||||
|  |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             original_logits = original_model( |  | ||||||
|                 {"image": original_pixel_values, "text_input": [caption]}, match_head="itm" |  | ||||||
|             ) |  | ||||||
|             logits = hf_model( |  | ||||||
|                 pixel_values=pixel_values, |  | ||||||
|                 input_ids=input_ids, |  | ||||||
|                 attention_mask=attention_mask, |  | ||||||
|                 use_image_text_matching_head=True, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         assert original_logits.shape == logits.logits_per_image.shape |  | ||||||
|         print("First values of original logits:", original_logits[0, :3]) |  | ||||||
|         print("First values of HF logits:", logits.logits_per_image[0, :3]) |  | ||||||
|  |  | ||||||
|         # assert values |  | ||||||
|         # cast to same type |  | ||||||
|         target_dtype = logits.logits_per_image.dtype |  | ||||||
|         assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4) |  | ||||||
|  |  | ||||||
|         original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1) |  | ||||||
|         itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1) |  | ||||||
|         assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4) |  | ||||||
|         print("Looks ok!") |  | ||||||
|  |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             original_logits = original_model( |  | ||||||
|                 {"image": original_pixel_values, "text_input": [caption]}, match_head="itc" |  | ||||||
|             ) |  | ||||||
|             logits = hf_model( |  | ||||||
|                 pixel_values=pixel_values, |  | ||||||
|                 input_ids=input_ids, |  | ||||||
|                 attention_mask=attention_mask, |  | ||||||
|                 use_image_text_matching_head=False, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         assert original_logits.shape == logits.logits_per_image.shape |  | ||||||
|         print("First values of original logits:", original_logits[0, :3]) |  | ||||||
|         print("First values of HF logits:", logits.logits_per_image[0, :3]) |  | ||||||
|  |  | ||||||
|         # assert values |  | ||||||
|         # cast to same type |  | ||||||
|         target_dtype = logits.logits_per_image.dtype |  | ||||||
|         assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4) |  | ||||||
|         print("Looks ok!") |  | ||||||
|  |  | ||||||
|     else: |  | ||||||
|         input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device) |  | ||||||
|  |  | ||||||
|         with torch.no_grad(): |  | ||||||
|             if "opt" in model_name: |  | ||||||
|                 original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits |  | ||||||
|                 logits = hf_model(pixel_values, input_ids).logits |  | ||||||
|             else: |  | ||||||
|                 original_logits = original_model( |  | ||||||
|                     {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]} |  | ||||||
|                 ).logits |  | ||||||
|                 labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) |  | ||||||
|                 logits = hf_model(pixel_values, input_ids, labels=labels).logits |  | ||||||
|  |  | ||||||
|         assert original_logits.shape == logits.shape |  | ||||||
|         print("First values of original logits:", original_logits[0, :3, :3]) |  | ||||||
|         print("First values of HF logits:", logits[0, :3, :3]) |  | ||||||
|  |  | ||||||
|         # assert values |  | ||||||
|         assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4) |  | ||||||
|         print("Looks ok!") |  | ||||||
|  |  | ||||||
|         print("Generating a caption...") |  | ||||||
|         prompt = "Question: what object is in this image? Answer:" |  | ||||||
|         input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device) |  | ||||||
|  |  | ||||||
|         set_seed(42) |  | ||||||
|  |  | ||||||
|         original_outputs = original_model.generate( |  | ||||||
|             {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50 |  | ||||||
|         ) |  | ||||||
|         outputs = hf_model.generate( |  | ||||||
|             pixel_values, |  | ||||||
|             input_ids, |  | ||||||
|             do_sample=True, |  | ||||||
|             num_beams=5, |  | ||||||
|             max_length=30, |  | ||||||
|             min_length=1, |  | ||||||
|             top_p=0.9, |  | ||||||
|             repetition_penalty=1.0, |  | ||||||
|             length_penalty=1.0, |  | ||||||
|             temperature=1, |  | ||||||
|         ) |  | ||||||
|         output_text = processor.batch_decode(outputs, skip_special_tokens=True) |  | ||||||
|         output_text = [text.strip() for text in output_text] |  | ||||||
|         print("Original generation:", original_outputs) |  | ||||||
|         print("HF generation:", output_text) |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         hf_model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         processor.push_to_hub(f"nielsr/{model_name}") |  | ||||||
|         hf_model.push_to_hub(f"nielsr/{model_name}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     choices = [ |  | ||||||
|         "blip2-opt-2.7b", |  | ||||||
|         "blip2-opt-6.7b", |  | ||||||
|         "blip2-opt-2.7b-coco", |  | ||||||
|         "blip2-opt-6.7b-coco", |  | ||||||
|         "blip2-flan-t5-xl", |  | ||||||
|         "blip2-flan-t5-xl-coco", |  | ||||||
|         "blip2-flan-t5-xxl", |  | ||||||
|         "blip2-itm-vit-g", |  | ||||||
|         "blip2-itm-vit-g-coco", |  | ||||||
|     ] |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="blip2-opt-2.7b", |  | ||||||
|         choices=choices, |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the BLIP-2 model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", |  | ||||||
|         action="store_true", |  | ||||||
|         help="Whether to push the model and processor to the hub after converting", |  | ||||||
|     ) |  | ||||||
|     # note: this script is tested on 2 GPUs, as the models are compared in float32, |  | ||||||
|     # which requires quite some memory. Hence loading each model on a |  | ||||||
|     # separate device is the easiest way to compare them. |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_blip2_checkpoint( |  | ||||||
|         args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device |  | ||||||
|     ) |  | ||||||
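After conversion, a BLIP-2 checkpoint can answer the same kind of prompt used for verification above through `Blip2Processor` and `Blip2ForConditionalGeneration`. A minimal sketch, with `Salesforce/blip2-opt-2.7b` used only as an example of an already-converted checkpoint:

```py
import requests
from PIL import Image

from transformers import Blip2ForConditionalGeneration, Blip2Processor

# "Salesforce/blip2-opt-2.7b" is only an example of a converted checkpoint
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

prompt = "Question: what object is in this image? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(out, skip_special_tokens=True)[0].strip())
```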
| @ -1239,6 +1239,9 @@ class Blip2TextEmbeddings(nn.Module): | |||||||
|                 embeddings += position_embeddings |                 embeddings += position_embeddings | ||||||
|  |  | ||||||
|             if query_embeds is not None: |             if query_embeds is not None: | ||||||
|  |                 # `query_embeds` is kept in fp32 when used with the Q-Former | ||||||
|  |                 if query_embeds.dtype != embeddings.dtype: | ||||||
|  |                     query_embeds = query_embeds.to(embeddings.dtype) | ||||||
|                 embeddings = torch.cat((query_embeds, embeddings), dim=1) |                 embeddings = torch.cat((query_embeds, embeddings), dim=1) | ||||||
|         else: |         else: | ||||||
|             embeddings = query_embeds |             embeddings = query_embeds | ||||||
| @ -1386,6 +1389,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel): | |||||||
|         # If a 2D or 3D attention mask is provided for the cross-attention |         # If a 2D or 3D attention mask is provided for the cross-attention | ||||||
|         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] | ||||||
|         if encoder_hidden_states is not None: |         if encoder_hidden_states is not None: | ||||||
|  |             # Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already | ||||||
|  |             if encoder_hidden_states.dtype != query_embeds.dtype: | ||||||
|  |                 encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype) | ||||||
|  |  | ||||||
|             if isinstance(encoder_hidden_states, list): |             if isinstance(encoder_hidden_states, list): | ||||||
|                 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() |                 encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() | ||||||
|             else: |             else: | ||||||
| @ -1448,6 +1455,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel): | |||||||
| class Blip2Model(Blip2PreTrainedModel): | class Blip2Model(Blip2PreTrainedModel): | ||||||
|     config_class = Blip2Config |     config_class = Blip2Config | ||||||
|     main_input_name = "pixel_values" |     main_input_name = "pixel_values" | ||||||
|  |     _keep_in_fp32_modules = ["query_tokens", "qformer"] | ||||||
|  |  | ||||||
|     def __init__(self, config: Blip2Config): |     def __init__(self, config: Blip2Config): | ||||||
|         super().__init__(config) |         super().__init__(config) | ||||||
| @ -1728,6 +1736,10 @@ class Blip2Model(Blip2PreTrainedModel): | |||||||
|         ) |         ) | ||||||
|         query_output = query_outputs[0] |         query_output = query_outputs[0] | ||||||
|  |  | ||||||
|  |         # Qformer is kept in fp32, we downcast the output back if needed | ||||||
|  |         if query_output.dtype != image_embeds.dtype: | ||||||
|  |             query_output = query_output.to(image_embeds.dtype) | ||||||
|  |  | ||||||
|         # step 3: use the language model, conditioned on the query outputs and the prompt |         # step 3: use the language model, conditioned on the query outputs and the prompt | ||||||
|         language_model_inputs = self.language_projection(query_output) |         language_model_inputs = self.language_projection(query_output) | ||||||
|         language_model_attention_mask = torch.ones( |         language_model_attention_mask = torch.ones( | ||||||
| @ -1799,7 +1811,7 @@ class Blip2Model(Blip2PreTrainedModel): | |||||||
| ) | ) | ||||||
| class Blip2TextModelWithProjection(Blip2PreTrainedModel): | class Blip2TextModelWithProjection(Blip2PreTrainedModel): | ||||||
|     supports_gradient_checkpointing = False |     supports_gradient_checkpointing = False | ||||||
|     _keep_in_fp32_modules = ["query_tokens"] |     _keep_in_fp32_modules = ["query_tokens", "qformer"] | ||||||
|  |  | ||||||
|     def __init__(self, config: Blip2Config): |     def __init__(self, config: Blip2Config): | ||||||
|         super().__init__(config) |         super().__init__(config) | ||||||
| @ -1898,7 +1910,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): | |||||||
| ) | ) | ||||||
| class Blip2VisionModelWithProjection(Blip2PreTrainedModel): | class Blip2VisionModelWithProjection(Blip2PreTrainedModel): | ||||||
|     main_input_name = "pixel_values" |     main_input_name = "pixel_values" | ||||||
|     _keep_in_fp32_modules = ["query_tokens"] |     _keep_in_fp32_modules = ["query_tokens", "qformer"] | ||||||
|  |  | ||||||
|     def __init__(self, config: Blip2Config): |     def __init__(self, config: Blip2Config): | ||||||
|         super().__init__(config) |         super().__init__(config) | ||||||
| @ -2019,6 +2031,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): | |||||||
|     _supports_cache_class = True |     _supports_cache_class = True | ||||||
|     _supports_static_cache = True |     _supports_static_cache = True | ||||||
|     _supports_quantized_cache = False  # not all LM backbones support it (e.g. T5) |     _supports_quantized_cache = False  # not all LM backbones support it (e.g. T5) | ||||||
|  |     _keep_in_fp32_modules = ["query_tokens", "qformer"] | ||||||
|  |  | ||||||
|     def __init__(self, config: Blip2Config): |     def __init__(self, config: Blip2Config): | ||||||
|         super().__init__(config) |         super().__init__(config) | ||||||
| @ -2191,6 +2204,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): | |||||||
|         ) |         ) | ||||||
|         query_output = query_outputs[0] |         query_output = query_outputs[0] | ||||||
|  |  | ||||||
|  |         # Qformer is kept in fp32, we downcast the output back if needed | ||||||
|  |         if query_output.dtype != image_embeds.dtype: | ||||||
|  |             query_output = query_output.to(image_embeds.dtype) | ||||||
|  |  | ||||||
|         # step 3: use the language model, conditioned on the query outputs and the prompt |         # step 3: use the language model, conditioned on the query outputs and the prompt | ||||||
|         language_model_inputs = self.language_projection(query_output) |         language_model_inputs = self.language_projection(query_output) | ||||||
|         language_model_attention_mask = torch.ones( |         language_model_attention_mask = torch.ones( | ||||||
| @ -2312,6 +2329,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): | |||||||
|         ) |         ) | ||||||
|         query_output = query_outputs.last_hidden_state |         query_output = query_outputs.last_hidden_state | ||||||
|  |  | ||||||
|  |         # Qformer is kept in fp32, we downcast the output back if needed | ||||||
|  |         if query_output.dtype != image_embeds.dtype: | ||||||
|  |             query_output = query_output.to(image_embeds.dtype) | ||||||
|  |  | ||||||
|         language_model_inputs = self.language_projection(query_output) |         language_model_inputs = self.language_projection(query_output) | ||||||
|         language_attention_mask = torch.ones( |         language_attention_mask = torch.ones( | ||||||
|             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device |             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device | ||||||
| @ -2371,7 +2392,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): | |||||||
| ) | ) | ||||||
| class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): | class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): | ||||||
|     main_input_name = "pixel_values" |     main_input_name = "pixel_values" | ||||||
|     _keep_in_fp32_modules = ["query_tokens"] |     _keep_in_fp32_modules = ["query_tokens", "qformer"] | ||||||
|  |  | ||||||
|     def __init__(self, config: Blip2Config): |     def __init__(self, config: Blip2Config): | ||||||
|         super().__init__(config) |         super().__init__(config) | ||||||
|  | |||||||
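The modeling changes above keep the Q-Former and the learned query tokens in fp32 even when the rest of the model is loaded in half precision, and downcast the Q-Former output before it reaches the language projection. A minimal sketch of how this can be checked after loading; `Salesforce/blip2-opt-2.7b` is used only as an example checkpoint, and the dtypes in the comments are the ones expected under these changes.

```py
import torch

from transformers import Blip2ForConditionalGeneration

# load in half precision; modules listed in `_keep_in_fp32_modules` should remain in float32
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)

print(model.query_tokens.dtype)                       # expected: torch.float32
print(next(model.qformer.parameters()).dtype)         # expected: torch.float32
print(next(model.vision_model.parameters()).dtype)    # expected: torch.float16
print(next(model.language_model.parameters()).dtype)  # expected: torch.float16
```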
| @ -1,254 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert BigScience BLOOM checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| import os |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BloomConfig, BloomModel |  | ||||||
| from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
| WEIGHTS_TO_AVERAGE_ENDSWITH = [ |  | ||||||
|     "word_embeddings_layernorm.weight", |  | ||||||
|     "word_embeddings_layernorm.bias", |  | ||||||
|     "input_layernorm.weight", |  | ||||||
|     "input_layernorm.bias", |  | ||||||
|     "post_attention_layernorm.weight", |  | ||||||
|     "post_attention_layernorm.bias", |  | ||||||
|     "self_attention.dense.bias", |  | ||||||
|     "mlp.dense_4h_to_h.bias", |  | ||||||
|     "ln_f.weight", |  | ||||||
|     "ln_f.bias", |  | ||||||
| ] |  | ||||||
|  |  | ||||||
| WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [ |  | ||||||
|     "mlp.dense_4h_to_h.weight", |  | ||||||
|     "self_attention.dense.weight", |  | ||||||
| ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def layer_name_mapping(key, file): |  | ||||||
|     """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only""" |  | ||||||
|     # Handle first and last layers |  | ||||||
|     layer_rename_map = { |  | ||||||
|         "word_embeddings.weight": "word_embeddings.weight", |  | ||||||
|         "word_embeddings.norm.weight": "word_embeddings_layernorm.weight", |  | ||||||
|         "word_embeddings.norm.bias": "word_embeddings_layernorm.bias", |  | ||||||
|         "weight": "ln_f.weight", |  | ||||||
|         "bias": "ln_f.bias", |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if key in layer_rename_map: |  | ||||||
|         return layer_rename_map[key] |  | ||||||
|  |  | ||||||
|     # Handle transformer blocks |  | ||||||
|     layer_number = int(re.match(r".*layer_(\d*).*", file)[1]) |  | ||||||
|     layer_number -= 3 |  | ||||||
|     return f"h.{layer_number}." + key |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_dtype_size(dtype): |  | ||||||
|     if dtype == torch.bool: |  | ||||||
|         return 1 / 8 |  | ||||||
|     bit_search = re.search(r"[^\d](\d+)$", str(dtype)) |  | ||||||
|     if bit_search is None: |  | ||||||
|         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") |  | ||||||
|     bit_size = int(bit_search.groups()[0]) |  | ||||||
|     return bit_size // 8 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_bloom_checkpoint_to_pytorch( |  | ||||||
|     bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp |  | ||||||
| ): |  | ||||||
|     # Construct model |  | ||||||
|     if bloom_config_file == "": |  | ||||||
|         config = BloomConfig() |  | ||||||
|     else: |  | ||||||
|         config = BloomConfig.from_json_file(bloom_config_file) |  | ||||||
|  |  | ||||||
|     if shard_model: |  | ||||||
|         file_names = os.listdir(bloom_checkpoint_path) |  | ||||||
|         file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names)) |  | ||||||
|  |  | ||||||
|         index_dict = {"weight_map": {}, "metadata": {}} |  | ||||||
|         total_size = 0 |  | ||||||
|  |  | ||||||
|         missing_keys = None |  | ||||||
|  |  | ||||||
|         config = BloomConfig() |  | ||||||
|  |  | ||||||
|         for j, file in enumerate(file_names): |  | ||||||
|             print("Processing file: {}".format(file)) |  | ||||||
|             tensors = None |  | ||||||
|  |  | ||||||
|             for i in range(pretraining_tp): |  | ||||||
|                 # load all TP files |  | ||||||
|                 f_name = file.replace("model_00", f"model_0{i}") |  | ||||||
|                 temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu") |  | ||||||
|  |  | ||||||
|                 # Rename keys to the transformers names |  | ||||||
|                 keys = list(temp.keys()) |  | ||||||
|                 for key in keys: |  | ||||||
|                     temp[layer_name_mapping(key, file)] = temp.pop(key) |  | ||||||
|  |  | ||||||
|                 if tensors is None: |  | ||||||
|                     tensors = temp |  | ||||||
|                 else: |  | ||||||
|                     for key in tensors.keys(): |  | ||||||
|                         if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): |  | ||||||
|                             # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425) |  | ||||||
|                             tensors[key] += temp[key] |  | ||||||
|                         else: |  | ||||||
|                             # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel |  | ||||||
|                             cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0 |  | ||||||
|                             # We concatenate these weights across TP ranks |  | ||||||
|                             tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim) |  | ||||||
|  |  | ||||||
|             # Divide the weights we want to average by the number of TP ranks |  | ||||||
|             for key in tensors.keys(): |  | ||||||
|                 if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): |  | ||||||
|                     tensors[key] = tensors[key] / pretraining_tp |  | ||||||
|             torch.save( |  | ||||||
|                 tensors, |  | ||||||
|                 os.path.join( |  | ||||||
|                     pytorch_dump_folder_path, |  | ||||||
|                     "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)), |  | ||||||
|                 ), |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             for key in tensors.keys(): |  | ||||||
|                 value = tensors[key] |  | ||||||
|                 total_size += value.numel() * get_dtype_size(value.dtype) |  | ||||||
|                 if key not in index_dict["weight_map"]: |  | ||||||
|                     index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format( |  | ||||||
|                         str(j + 1).zfill(5), str(len(file_names)).zfill(5) |  | ||||||
|                     ) |  | ||||||
|  |  | ||||||
|         config = BloomConfig() |  | ||||||
|         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME |  | ||||||
|         index_dict["metadata"]["total_size"] = total_size |  | ||||||
|         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: |  | ||||||
|             f.write(config.to_json_string()) |  | ||||||
|         with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f: |  | ||||||
|             json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n" |  | ||||||
|             f.write(json_config) |  | ||||||
|     else: |  | ||||||
|         model = BloomModel(config) |  | ||||||
|  |  | ||||||
|         file_names = os.listdir(bloom_checkpoint_path) |  | ||||||
|         file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names)) |  | ||||||
|  |  | ||||||
|         missing_keys = None |  | ||||||
|         for i, file in enumerate(file_names): |  | ||||||
|             tensors = None |  | ||||||
|             for i in range(pretraining_tp): |  | ||||||
|                 # load all TP files |  | ||||||
|                 f_name = file.replace("model_00", f"model_0{i}") |  | ||||||
|                 temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu") |  | ||||||
|  |  | ||||||
|                 # Rename keys to the transformers names |  | ||||||
|                 keys = list(temp.keys()) |  | ||||||
|                 for key in keys: |  | ||||||
|                     temp[layer_name_mapping(key, file)] = temp.pop(key) |  | ||||||
|  |  | ||||||
|                 if tensors is None: |  | ||||||
|                     tensors = temp |  | ||||||
|                 else: |  | ||||||
|                     for key in tensors.keys(): |  | ||||||
|                         # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425) |  | ||||||
|                         if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): |  | ||||||
|                             tensors[key] += temp[key] |  | ||||||
|                         else: |  | ||||||
|                             # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel |  | ||||||
|                             cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0 |  | ||||||
|                             # We concatenate these weights across TP ranks |  | ||||||
|                             tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim) |  | ||||||
|  |  | ||||||
|             # Divide the weights we want to average by the number of TP ranks |  | ||||||
|             for key in tensors.keys(): |  | ||||||
|                 if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH): |  | ||||||
|                     tensors[key] = tensors[key] / pretraining_tp |  | ||||||
|  |  | ||||||
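|             # each layer file only covers part of the model, so load non-strictly and keep track of the keys that are still missing |  | ||||||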
|             other_keys = model.load_state_dict(tensors, strict=False) |  | ||||||
|             assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected" |  | ||||||
|             if missing_keys is None: |  | ||||||
|                 missing_keys = set(other_keys.missing_keys) |  | ||||||
|             else: |  | ||||||
|                 missing_keys = missing_keys.intersection(set(other_keys.missing_keys)) |  | ||||||
|  |  | ||||||
|         assert not missing_keys, f"The keys {missing_keys} are missing" |  | ||||||
|  |  | ||||||
|         # Save pytorch-model |  | ||||||
|         os.makedirs(pytorch_dump_folder_path, exist_ok=True) |  | ||||||
|         pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME |  | ||||||
|         pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME |  | ||||||
|         print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}") |  | ||||||
|         if config.torch_dtype is not None: |  | ||||||
|             model = model.to(config.torch_dtype) |  | ||||||
|         torch.save(model.state_dict(), pytorch_weights_dump_path) |  | ||||||
|         print(f"Save configuration file to {pytorch_config_dump_path}") |  | ||||||
|         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: |  | ||||||
|             f.write(config.to_json_string()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--bloom_checkpoint_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to the Megatron-LM checkpoint.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--bloom_config_file", |  | ||||||
|         default="", |  | ||||||
|         type=str, |  | ||||||
|         help=( |  | ||||||
|             "An optional config json file corresponding to the pre-trained model. \n" |  | ||||||
|             "This specifies the model architecture." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--shard_model", |  | ||||||
|         action="store_true", |  | ||||||
|         help="An optional setting to shard the output model. \nThis enables sharding the converted checkpoint.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pretraining_tp", |  | ||||||
|         default=4, |  | ||||||
|         type=int, |  | ||||||
|         help="Pretraining TP degree (tensor parallelism) that was used when training the model in Megatron-LM \n", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_bloom_checkpoint_to_pytorch( |  | ||||||
|         args.bloom_checkpoint_path, |  | ||||||
|         args.bloom_config_file, |  | ||||||
|         args.pytorch_dump_folder_path, |  | ||||||
|         args.shard_model, |  | ||||||
|         args.pretraining_tp, |  | ||||||
|     ) |  | ||||||
| @ -1,145 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert Bros checkpoints.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import bros  # original repo |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import BrosConfig, BrosModel, BrosProcessor |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_configs(model_name): |  | ||||||
|     bros_config = BrosConfig.from_pretrained(model_name) |  | ||||||
|     return bros_config |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def remove_ignore_keys_(state_dict): |  | ||||||
|     ignore_keys = [ |  | ||||||
|         "embeddings.bbox_sinusoid_emb.inv_freq", |  | ||||||
|     ] |  | ||||||
|     for k in ignore_keys: |  | ||||||
|         state_dict.pop(k, None) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     if name == "embeddings.bbox_projection.weight": |  | ||||||
|         name = "bbox_embeddings.bbox_projection.weight" |  | ||||||
|  |  | ||||||
|     if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq": |  | ||||||
|         name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq" |  | ||||||
|  |  | ||||||
|     if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq": |  | ||||||
|         name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq" |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_state_dict(orig_state_dict, model): |  | ||||||
|     # rename keys |  | ||||||
|     for key in orig_state_dict.copy().keys(): |  | ||||||
|         val = orig_state_dict.pop(key) |  | ||||||
|         orig_state_dict[rename_key(key)] = val |  | ||||||
|  |  | ||||||
|     # remove ignore keys |  | ||||||
|     remove_ignore_keys_(orig_state_dict) |  | ||||||
|  |  | ||||||
|     return orig_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False): |  | ||||||
|     # load original model |  | ||||||
|     original_model = bros.BrosModel.from_pretrained(model_name).eval() |  | ||||||
|  |  | ||||||
|     # load HuggingFace Model |  | ||||||
|     bros_config = get_configs(model_name) |  | ||||||
|     model = BrosModel.from_pretrained(model_name, config=bros_config) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     state_dict = original_model.state_dict() |  | ||||||
|     new_state_dict = convert_state_dict(state_dict, model) |  | ||||||
|     model.load_state_dict(new_state_dict) |  | ||||||
|  |  | ||||||
|     # verify results |  | ||||||
|  |  | ||||||
|     # the original BROS model requires 4 points (8 float values) for each bbox, so prepare a bbox with [batch_size, seq_len, 8] shape |  | ||||||
|     bbox = torch.tensor( |  | ||||||
|         [ |  | ||||||
|             [ |  | ||||||
|                 [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], |  | ||||||
|                 [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850], |  | ||||||
|                 [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850], |  | ||||||
|                 [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850], |  | ||||||
|                 [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000], |  | ||||||
|                 [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000], |  | ||||||
|                 [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], |  | ||||||
|             ] |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     processor = BrosProcessor.from_pretrained(model_name) |  | ||||||
|  |  | ||||||
|     encoding = processor("His name is Rocco.", return_tensors="pt") |  | ||||||
|     encoding["bbox"] = bbox |  | ||||||
|  |  | ||||||
|     original_hidden_states = original_model(**encoding).last_hidden_state |  | ||||||
|     # pixel_values = processor(image, return_tensors="pt").pixel_values |  | ||||||
|  |  | ||||||
|     last_hidden_states = model(**encoding).last_hidden_state |  | ||||||
|  |  | ||||||
|     assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4) |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         print(f"Saving model and processor to {pytorch_dump_folder_path}") |  | ||||||
|         model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model") |  | ||||||
|         processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|  |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="jinho8345/bros-base-uncased", |  | ||||||
|         required=False, |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the original model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default=None, |  | ||||||
|         required=False, |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the output PyTorch model directory.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", |  | ||||||
|         action="store_true", |  | ||||||
|         help="Whether or not to push the converted model and processor to the 🤗 hub.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) |  | ||||||
| @ -1,59 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2018 The T5 authors and HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert T5 checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): |  | ||||||
|     # Initialise PyTorch model |  | ||||||
|     config = T5Config.from_json_file(config_file) |  | ||||||
|     print(f"Building PyTorch model from configuration: {config}") |  | ||||||
|     model = T5ForConditionalGeneration(config) |  | ||||||
|  |  | ||||||
|     # Load weights from tf checkpoint |  | ||||||
|     load_tf_weights_in_t5(model, config, tf_checkpoint_path) |  | ||||||
|  |  | ||||||
|     # Save pytorch-model |  | ||||||
|     print(f"Save PyTorch model to {pytorch_dump_path}") |  | ||||||
|     model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--config_file", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help=( |  | ||||||
|             "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path) |  | ||||||
| @ -1,65 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert CANINE checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path): |  | ||||||
|     # Initialize PyTorch model |  | ||||||
|     config = CanineConfig() |  | ||||||
|     model = CanineModel(config) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     print(f"Building PyTorch model from configuration: {config}") |  | ||||||
|  |  | ||||||
|     # Load weights from tf checkpoint |  | ||||||
|     load_tf_weights_in_canine(model, config, tf_checkpoint_path) |  | ||||||
|  |  | ||||||
|     # Save pytorch-model (weights and configuration) |  | ||||||
|     print(f"Save PyTorch model to {pytorch_dump_path}") |  | ||||||
|     model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|     # Save tokenizer files |  | ||||||
|     tokenizer = CanineTokenizer() |  | ||||||
|     print(f"Save tokenizer files to {pytorch_dump_path}") |  | ||||||
|     tokenizer.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to the TensorFlow checkpoint. Should end with model.ckpt", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to a folder where the PyTorch model will be placed.", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path) |  | ||||||
| @ -1,476 +0,0 @@ | |||||||
| # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| import argparse |  | ||||||
| import gc |  | ||||||
| import json |  | ||||||
| import os |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| import yaml |  | ||||||
| from accelerate import init_empty_weights |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     ChameleonConfig, |  | ||||||
|     ChameleonForConditionalGeneration, |  | ||||||
|     ChameleonImageProcessor, |  | ||||||
|     ChameleonProcessor, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| try: |  | ||||||
|     from transformers import LlamaTokenizerFast |  | ||||||
| except ImportError: |  | ||||||
|     raise ValueError( |  | ||||||
|         "Chameleon conversion only supports the fast tokenizer, but LlamaTokenizerFast can't be imported! " |  | ||||||
|         "Update your `tokenizers` library and re-run the tokenizer conversion." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
| """ |  | ||||||
| Sample usage: |  | ||||||
|  |  | ||||||
| ``` |  | ||||||
| python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \ |  | ||||||
|     --input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| Thereafter, models can be loaded via: |  | ||||||
|  |  | ||||||
| ```py |  | ||||||
| from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast |  | ||||||
|  |  | ||||||
| model = ChameleonForConditionalGeneration.from_pretrained("/output/path") |  | ||||||
| tokenizer = LlamaTokenizerFast.from_pretrained("/output/path") |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions |  | ||||||
| come in several checkpoints, each checkpoint contains a part of each weight of the model, so we need to load them all in RAM). |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| NUM_SHARDS = { |  | ||||||
|     "7B": 1, |  | ||||||
|     "30B": 4, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| VOCAB_SIZE = 65536 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): |  | ||||||
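|     # FFN width used by the checkpoints: 8/3 of the hidden size (times ffn_dim_multiplier), rounded up to a multiple of multiple_of |  | ||||||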
|     return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_json(path): |  | ||||||
|     with open(path, "r") as f: |  | ||||||
|         return json.load(f) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def write_json(text, path): |  | ||||||
|     with open(path, "w") as f: |  | ||||||
|         json.dump(text, f) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def write_model(model_path, input_base_path, model_size, chameleon_version=1): |  | ||||||
|     os.makedirs(model_path, exist_ok=True) |  | ||||||
|     input_model_path = os.path.join(input_base_path, "models", model_size.lower()) |  | ||||||
|     params_path = os.path.join(input_model_path, "params.json") |  | ||||||
|     consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json") |  | ||||||
|  |  | ||||||
|     params = read_json(params_path) |  | ||||||
|     if os.path.isfile(consolidate_params_path): |  | ||||||
|         params = {**params, **read_json(consolidate_params_path)} |  | ||||||
|     num_shards = NUM_SHARDS[model_size] |  | ||||||
|     model_parallel_size = params["model_parallel_size"] |  | ||||||
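|     # some param dumps nest the hyper-parameters under a "model" key; fall back to the flat layout otherwise |  | ||||||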
|     params = params.get("model", params) |  | ||||||
|     n_layers = params["n_layers"] |  | ||||||
|     n_heads = params["n_heads"] |  | ||||||
|     n_heads_per_shard = n_heads // num_shards |  | ||||||
|     dim = params["dim"] |  | ||||||
|     dims_per_head = dim // n_heads |  | ||||||
|     base = params.get("rope_theta", 10000.0) |  | ||||||
|     swin_norm = params["swin_norm"] |  | ||||||
|     if base > 10000.0: |  | ||||||
|         max_position_embeddings = 16384 |  | ||||||
|     else: |  | ||||||
|         # Depending on the Chameleon version, the default max_position_embeddings has different values. |  | ||||||
|         if chameleon_version == 1: |  | ||||||
|             max_position_embeddings = 4096 |  | ||||||
|         else: |  | ||||||
|             raise NotImplementedError( |  | ||||||
|                 f"Version {chameleon_version} of chameleon is not supported yet. " |  | ||||||
|                 "Current supported versions of chameleon are [1]." |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     if params.get("n_kv_heads", None) is not None: |  | ||||||
|         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA |  | ||||||
|         num_local_key_value_heads = n_heads_per_shard // num_key_value_heads |  | ||||||
|         key_value_dim = dim // num_key_value_heads |  | ||||||
|     else:  # compatibility with other checkpoints |  | ||||||
|         num_key_value_heads = n_heads |  | ||||||
|         num_local_key_value_heads = n_heads_per_shard |  | ||||||
|         key_value_dim = dim |  | ||||||
|  |  | ||||||
|     print(f"Fetching all parameters from the checkpoint at {input_model_path}.") |  | ||||||
|     # Load weights |  | ||||||
|     if num_shards == 1: |  | ||||||
|         # Not sharded |  | ||||||
|         # (The sharded implementation would also work, but this is simpler.) |  | ||||||
|         loaded = None |  | ||||||
|         for possible_name in ["consolidated.pth", "consolidated.00.pth"]: |  | ||||||
|             possible_path = os.path.join(input_model_path, possible_name) |  | ||||||
|             if os.path.exists(possible_path): |  | ||||||
|                 loaded = torch.load(possible_path, map_location="cpu") |  | ||||||
|                 break |  | ||||||
|         assert loaded is not None |  | ||||||
|     else: |  | ||||||
|         # Sharded |  | ||||||
|         loaded = [ |  | ||||||
|             torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu") |  | ||||||
|             for i in range(num_shards) |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     # permute for sliced rotary |  | ||||||
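|     # (reorders the projection rows from the original interleaved rotary layout to the half-split layout used by the HF implementation) |  | ||||||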
|     def permute(w, n_heads, dim1=dim, dim2=dim): |  | ||||||
|         return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) |  | ||||||
|  |  | ||||||
|     # Load weights to the state dict |  | ||||||
|     state_dict = {} |  | ||||||
|     for layer_i in range(n_layers): |  | ||||||
|         if num_shards == 1: |  | ||||||
|             # Unsharded |  | ||||||
|             state_dict.update( |  | ||||||
|                 { |  | ||||||
|                     f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( |  | ||||||
|                         loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads |  | ||||||
|                     ), |  | ||||||
|                     f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( |  | ||||||
|                         loaded[f"layers.{layer_i}.attention.wk.weight"], |  | ||||||
|                         n_heads=num_key_value_heads, |  | ||||||
|                         dim1=key_value_dim, |  | ||||||
|                     ), |  | ||||||
|                     f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], |  | ||||||
|                     f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], |  | ||||||
|                     f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], |  | ||||||
|                     f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], |  | ||||||
|                     f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], |  | ||||||
|                     f"model.layers.{layer_i}.input_layernorm.weight": loaded[ |  | ||||||
|                         f"layers.{layer_i}.attention_norm.weight" |  | ||||||
|                     ], |  | ||||||
|                     f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ |  | ||||||
|                         f"layers.{layer_i}.ffn_norm.weight" |  | ||||||
|                     ], |  | ||||||
|                 } |  | ||||||
|             ) |  | ||||||
|             # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677) |  | ||||||
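|             # the per-head q/k norm parameters are stored interleaved; un-interleave them and tile one copy per attention head |  | ||||||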
|             state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = ( |  | ||||||
|                 loaded[f"layers.{layer_i}.attention.q_normalization.weight"] |  | ||||||
|                 .view(dims_per_head // 2, 2) |  | ||||||
|                 .t() |  | ||||||
|                 .reshape(1, -1) |  | ||||||
|                 .repeat_interleave(n_heads, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = ( |  | ||||||
|                 loaded[f"layers.{layer_i}.attention.q_normalization.bias"] |  | ||||||
|                 .view(dims_per_head // 2, 2) |  | ||||||
|                 .t() |  | ||||||
|                 .reshape(1, -1) |  | ||||||
|                 .repeat_interleave(n_heads, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = ( |  | ||||||
|                 loaded[f"layers.{layer_i}.attention.k_normalization.weight"] |  | ||||||
|                 .view(dims_per_head // 2, 2) |  | ||||||
|                 .t() |  | ||||||
|                 .reshape(1, -1) |  | ||||||
|                 .repeat_interleave(num_key_value_heads, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = ( |  | ||||||
|                 loaded[f"layers.{layer_i}.attention.k_normalization.bias"] |  | ||||||
|                 .view(dims_per_head // 2, 2) |  | ||||||
|                 .t() |  | ||||||
|                 .reshape(1, -1) |  | ||||||
|                 .repeat_interleave(num_key_value_heads, 0) |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         else: |  | ||||||
|             # Sharded |  | ||||||
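|             # column-parallel weights (wq/wk/wv, w1, w3) are concatenated along dim 0, row-parallel ones (wo, w2) along dim 1, and the layer norms are averaged across shards |  | ||||||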
|             state_dict.update( |  | ||||||
|                 { |  | ||||||
|                     f"model.layers.{layer_i}.input_layernorm.weight": torch.stack( |  | ||||||
|                         [l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded] |  | ||||||
|                     ).mean(dim=0), |  | ||||||
|                     f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack( |  | ||||||
|                         [l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded] |  | ||||||
|                     ).mean(dim=0), |  | ||||||
|                 } |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( |  | ||||||
|                 torch.cat( |  | ||||||
|                     [ |  | ||||||
|                         loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) |  | ||||||
|                         for i in range(num_shards) |  | ||||||
|                     ], |  | ||||||
|                     dim=0, |  | ||||||
|                 ).reshape(dim, dim), |  | ||||||
|                 n_heads=n_heads, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( |  | ||||||
|                 torch.cat( |  | ||||||
|                     [ |  | ||||||
|                         loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( |  | ||||||
|                             num_local_key_value_heads, dims_per_head, dim |  | ||||||
|                         ) |  | ||||||
|                         for i in range(num_shards) |  | ||||||
|                     ], |  | ||||||
|                     dim=0, |  | ||||||
|                 ).reshape(key_value_dim, dim), |  | ||||||
|                 n_heads=num_key_value_heads, |  | ||||||
|                 dim1=key_value_dim, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = ( |  | ||||||
|                 torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded]) |  | ||||||
|                 .view(num_shards, dims_per_head // 2, 2) |  | ||||||
|                 .transpose(1, 2) |  | ||||||
|                 .reshape(num_shards, -1) |  | ||||||
|                 .repeat_interleave(n_heads // num_shards, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = ( |  | ||||||
|                 torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded]) |  | ||||||
|                 .view(num_shards, dims_per_head // 2, 2) |  | ||||||
|                 .transpose(1, 2) |  | ||||||
|                 .reshape(num_shards, -1) |  | ||||||
|                 .repeat_interleave(n_heads // num_shards, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = ( |  | ||||||
|                 torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded]) |  | ||||||
|                 .view(num_shards, dims_per_head // 2, 2) |  | ||||||
|                 .transpose(1, 2) |  | ||||||
|                 .reshape(num_shards, -1) |  | ||||||
|                 .repeat_interleave(num_key_value_heads // num_shards, 0) |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = ( |  | ||||||
|                 torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded]) |  | ||||||
|                 .view(num_shards, dims_per_head // 2, 2) |  | ||||||
|                 .transpose(1, 2) |  | ||||||
|                 .reshape(num_shards, -1) |  | ||||||
|                 .repeat_interleave(num_key_value_heads // num_shards, 0) |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( |  | ||||||
|                 [ |  | ||||||
|                     loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( |  | ||||||
|                         num_local_key_value_heads, dims_per_head, dim |  | ||||||
|                     ) |  | ||||||
|                     for i in range(num_shards) |  | ||||||
|                 ], |  | ||||||
|                 dim=0, |  | ||||||
|             ).reshape(key_value_dim, dim) |  | ||||||
|  |  | ||||||
|             state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( |  | ||||||
|                 [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( |  | ||||||
|                 [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( |  | ||||||
|                 [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 |  | ||||||
|             ) |  | ||||||
|             state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( |  | ||||||
|                 [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     if num_shards == 1: |  | ||||||
|         # Unsharded |  | ||||||
|         state_dict.update( |  | ||||||
|             { |  | ||||||
|                 "model.embed_tokens.weight": loaded["tok_embeddings.weight"], |  | ||||||
|                 "model.norm.weight": loaded["norm.weight"], |  | ||||||
|                 "lm_head.weight": loaded["output.weight"], |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|     else: |  | ||||||
|         state_dict.update( |  | ||||||
|             { |  | ||||||
|                 "model.embed_tokens.weight": torch.cat( |  | ||||||
|                     [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 |  | ||||||
|                 ), |  | ||||||
|                 "model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0), |  | ||||||
|                 "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     # Load VQGAN weights |  | ||||||
|     vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt") |  | ||||||
|     vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"] |  | ||||||
|     for k, v in vqgan_state_dict.items(): |  | ||||||
|         if "decoder" in k: |  | ||||||
|             continue  # we don't do image generation yet, so skip the decoder weights |  | ||||||
|         state_dict[f"model.vqmodel.{k}"] = v |  | ||||||
|  |  | ||||||
|     # Write configs |  | ||||||
|     ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 |  | ||||||
|     multiple_of = params["multiple_of"] if "multiple_of" in params else 256 |  | ||||||
|  |  | ||||||
|     with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file: |  | ||||||
|         tokenizer_config = json.load(tokenizer_file) |  | ||||||
|         vocabulary_map = tokenizer_config["model"]["vocab"] |  | ||||||
|         vocabulary_map["<image>"] = vocabulary_map[ |  | ||||||
|             "<reserved08707>" |  | ||||||
|         ]  # use a reserved token instead of adding a new one |  | ||||||
|         del vocabulary_map["<reserved08707>"] |  | ||||||
|  |  | ||||||
|         for token in tokenizer_config["added_tokens"]: |  | ||||||
|             if token["content"] == "<reserved08707>": |  | ||||||
|                 token["content"] = "<image>" |  | ||||||
|  |  | ||||||
|     with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f: |  | ||||||
|         json.dump(tokenizer_config, f)  # save the new file to init tokenizer later |  | ||||||
|  |  | ||||||
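|     # rename the original VQGAN config keys to the names expected by ChameleonConfig's vq_config |  | ||||||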
|     vq_keys_to_replace = [ |  | ||||||
|         ("ch", "base_channels"), |  | ||||||
|         ("out_ch", "out_channels"), |  | ||||||
|         ("n_embed", "num_embeddings"), |  | ||||||
|         ("ch_mult", "channel_multiplier"), |  | ||||||
|         ("double_z", "double_latent"), |  | ||||||
|         ("z_channels", "latent_channels"), |  | ||||||
|     ] |  | ||||||
|     with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file: |  | ||||||
|         vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"] |  | ||||||
|         vq_config.update(**vq_config["ddconfig"]) |  | ||||||
|         for old, new in vq_keys_to_replace: |  | ||||||
|             vq_config[new] = vq_config[old] |  | ||||||
|         del vq_config["ddconfig"] |  | ||||||
|         del vq_config["ckpt_path"] |  | ||||||
|         del vq_config["lossconfig"] |  | ||||||
|  |  | ||||||
|     config = ChameleonConfig( |  | ||||||
|         hidden_size=dim, |  | ||||||
|         intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), |  | ||||||
|         num_attention_heads=params["n_heads"], |  | ||||||
|         num_hidden_layers=params["n_layers"], |  | ||||||
|         rms_norm_eps=params["norm_eps"], |  | ||||||
|         num_key_value_heads=num_key_value_heads, |  | ||||||
|         vocab_size=VOCAB_SIZE, |  | ||||||
|         rope_theta=base, |  | ||||||
|         max_position_embeddings=max_position_embeddings, |  | ||||||
|         model_parallel_size=model_parallel_size, |  | ||||||
|         swin_norm=swin_norm, |  | ||||||
|         vq_config=vq_config, |  | ||||||
|         vocabulary_map=vocabulary_map, |  | ||||||
|     ) |  | ||||||
|     with init_empty_weights(): |  | ||||||
|         model = ChameleonForConditionalGeneration(config) |  | ||||||
|  |  | ||||||
|     model.load_state_dict(state_dict, assign=True, strict=False) |  | ||||||
|     model.save_pretrained(model_path, safe_serialization=True) |  | ||||||
|  |  | ||||||
|     # Load and save the processor |  | ||||||
|     tokenizer = LlamaTokenizerFast( |  | ||||||
|         tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False |  | ||||||
|     ) |  | ||||||
|     tokenizer.sep_token_id = 8710  # assign <reserved08706> to sep so that we can append it after input text |  | ||||||
|     tokenizer.pad_token_id = 1  # assign <pad> as the special pad_token |  | ||||||
|     image_processor = ChameleonImageProcessor() |  | ||||||
|     processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer) |  | ||||||
|     processor.save_pretrained(model_path) |  | ||||||
|  |  | ||||||
|     # Make space so we can load the model properly now. |  | ||||||
|     del state_dict |  | ||||||
|     del loaded |  | ||||||
|     del vqgan_state_dict |  | ||||||
|     gc.collect() |  | ||||||
|  |  | ||||||
|     # Short inference on a few examples to check if generation makes sense |  | ||||||
|     # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl |  | ||||||
|     print("Loading the checkpoint in a Chameleon model...") |  | ||||||
|     print("*" * 100) |  | ||||||
|     model = ChameleonForConditionalGeneration.from_pretrained( |  | ||||||
|         model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto" |  | ||||||
|     ) |  | ||||||
|     processor = ChameleonProcessor.from_pretrained(model_path) |  | ||||||
|  |  | ||||||
|     prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist." |  | ||||||
|     image = Image.open( |  | ||||||
|         requests.get( |  | ||||||
|             "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True |  | ||||||
|         ).raw |  | ||||||
|     ) |  | ||||||
|     inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16) |  | ||||||
|     length = inputs.input_ids.shape[1] |  | ||||||
|  |  | ||||||
|     out = model.generate(**inputs, max_new_tokens=40, do_sample=False) |  | ||||||
|     generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] |  | ||||||
|  |  | ||||||
|     print(f"Generation for single-image: {generated_text}") |  | ||||||
|     print("*" * 100) |  | ||||||
|  |  | ||||||
|     # Multi-image example |  | ||||||
|     prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation." |  | ||||||
|     image = Image.open( |  | ||||||
|         requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw |  | ||||||
|     ) |  | ||||||
|     image_2 = Image.open( |  | ||||||
|         requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16) |  | ||||||
|     length = inputs.input_ids.shape[1] |  | ||||||
|     out = model.generate(**inputs, max_new_tokens=50, do_sample=False) |  | ||||||
|     generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] |  | ||||||
|  |  | ||||||
|     print(f"Generation for multi-image: {generated_text}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--input_dir", |  | ||||||
|         help="Location of Chameleon weights", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_size", |  | ||||||
|         choices=["7B", "30B"], |  | ||||||
|         help="Size of the model to convert. The models correspond to the finetuned versions, and are specific to the" |  | ||||||
|         " Chameleon official release. For more details on Chameleon, check out the original repo:" |  | ||||||
|         " https://github.com/facebookresearch/chameleon", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--output_dir", |  | ||||||
|         help="Location to write HF model", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--test_inference", |  | ||||||
|         action="store_true", |  | ||||||
|         help="Whether to load the model for generation to test that it was converted correctly.", |  | ||||||
|     ) |  | ||||||
|     # Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--chameleon_version", |  | ||||||
|         choices=[1], |  | ||||||
|         default=1, |  | ||||||
|         type=int, |  | ||||||
|         help="Version of the Chameleon model to convert", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     write_model( |  | ||||||
|         model_path=args.output_dir, |  | ||||||
|         input_base_path=args.input_dir, |  | ||||||
|         model_size=args.model_size, |  | ||||||
|         chameleon_version=args.chameleon_version, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     main() |  | ||||||
| @ -1289,13 +1289,10 @@ class ChameleonModel(ChameleonPreTrainedModel): | |||||||
|                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" |                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         if inputs_embeds is None: |  | ||||||
|             inputs_embeds = self.embed_tokens(input_ids) |  | ||||||
|  |  | ||||||
|         if pixel_values is not None: |         if pixel_values is not None: | ||||||
|             image_tokens = self.get_image_tokens(pixel_values) |             image_tokens = self.get_image_tokens(pixel_values) | ||||||
|             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id |             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id | ||||||
|             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel(): |             if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel(): | ||||||
|                 n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() |                 n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() | ||||||
|                 n_image_features = image_tokens.shape[0] * image_tokens.shape[1] |                 n_image_features = image_tokens.shape[0] * image_tokens.shape[1] | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
| @ -1304,6 +1301,9 @@ class ChameleonModel(ChameleonPreTrainedModel): | |||||||
|             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) |             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) | ||||||
|             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) |             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) | ||||||
|  |  | ||||||
|  |         if inputs_embeds is None: | ||||||
|  |             inputs_embeds = self.embed_tokens(input_ids) | ||||||
|  |  | ||||||
|         # torch.jit.trace() doesn't support cache objects in the output |         # torch.jit.trace() doesn't support cache objects in the output | ||||||
|         if use_cache and past_key_values is None and not torch.jit.is_tracing(): |         if use_cache and past_key_values is None and not torch.jit.is_tracing(): | ||||||
|             past_key_values = DynamicCache() |             past_key_values = DynamicCache() | ||||||
|  | |||||||
| @ -1,134 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import ChineseCLIPConfig, ChineseCLIPModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_attn_layer(hf_attn_layer, pt_weights, prefix): |  | ||||||
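|     # the original checkpoint stores q, k and v as a single fused in_proj; split it into three equal chunks |  | ||||||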
|     q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0) |  | ||||||
|     q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0) |  | ||||||
|  |  | ||||||
|     out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"] |  | ||||||
|     out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"] |  | ||||||
|  |  | ||||||
|     hf_attn_layer.q_proj.weight.data = q_proj |  | ||||||
|     hf_attn_layer.q_proj.bias.data = q_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.k_proj.weight.data = k_proj |  | ||||||
|     hf_attn_layer.k_proj.bias.data = k_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.v_proj.weight.data = v_proj |  | ||||||
|     hf_attn_layer.v_proj.bias.data = v_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.out_proj.weight.data = out_proj_weights |  | ||||||
|     hf_attn_layer.out_proj.bias.data = out_proj_bias |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_mlp(hf_mlp, pt_weights, prefix): |  | ||||||
|     copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc") |  | ||||||
|     copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_linear(hf_linear, pt_weights, prefix): |  | ||||||
|     hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data |  | ||||||
|     hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_layer(hf_layer, pt_weights, prefix): |  | ||||||
|     # copy layer norms |  | ||||||
|     copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1") |  | ||||||
|     copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2") |  | ||||||
|  |  | ||||||
|     # copy MLP |  | ||||||
|     copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp") |  | ||||||
|  |  | ||||||
|     # copy attn |  | ||||||
|     copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_layers(hf_layers, pt_weights, prefix): |  | ||||||
|     for layer_id, hf_layer in enumerate(hf_layers): |  | ||||||
|         copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_text_model_and_projection(hf_model, pt_weights): |  | ||||||
|     # copy projection |  | ||||||
|     hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T |  | ||||||
|  |  | ||||||
|     # copy text encoder |  | ||||||
|     for name, param in hf_model.text_model.named_parameters(): |  | ||||||
|         param.data = pt_weights[f"bert.{name}"].data |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_vision_model_and_projection(hf_model, pt_weights): |  | ||||||
|     # copy projection |  | ||||||
|     hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T |  | ||||||
|  |  | ||||||
|     # copy layer norms |  | ||||||
|     copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre") |  | ||||||
|     copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post") |  | ||||||
|  |  | ||||||
|     # copy embeddings |  | ||||||
|     hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data |  | ||||||
|     hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data |  | ||||||
|     hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data |  | ||||||
|  |  | ||||||
|     # copy encoder |  | ||||||
|     copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to transformers design. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size." |  | ||||||
|     config = ChineseCLIPConfig.from_pretrained(config_path) |  | ||||||
|  |  | ||||||
|     hf_model = ChineseCLIPModel(config).eval() |  | ||||||
|  |  | ||||||
|     pt_weights = torch.load(checkpoint_path, map_location="cpu")["state_dict"] |  | ||||||
|     pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()} |  | ||||||
|  |  | ||||||
|     copy_text_model_and_projection(hf_model, pt_weights) |  | ||||||
|     copy_vision_model_and_projection(hf_model, pt_weights) |  | ||||||
|     hf_model.logit_scale.data = pt_weights["logit_scale"].data |  | ||||||
|  |  | ||||||
|     hf_model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the output folder storing converted hf PyTorch model.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) |  | ||||||
|     print("The conversion is finished!") |  | ||||||
| @ -1,133 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| from laion_clap import CLAP_Module |  | ||||||
|  |  | ||||||
| from transformers import AutoFeatureExtractor, ClapConfig, ClapModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| KEYS_TO_MODIFY_MAPPING = { |  | ||||||
|     "text_branch": "text_model", |  | ||||||
|     "audio_branch": "audio_model.audio_encoder", |  | ||||||
|     "attn": "attention.self", |  | ||||||
|     "self.proj": "output.dense", |  | ||||||
|     "attention.self_mask": "attn_mask", |  | ||||||
|     "mlp.fc1": "intermediate.dense", |  | ||||||
|     "mlp.fc2": "output.dense", |  | ||||||
|     "norm1": "layernorm_before", |  | ||||||
|     "norm2": "layernorm_after", |  | ||||||
|     "bn0": "batch_norm", |  | ||||||
| } |  | ||||||
|  |  | ||||||
| processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def init_clap(checkpoint_path, model_type, enable_fusion=False): |  | ||||||
|     model = CLAP_Module( |  | ||||||
|         amodel=model_type, |  | ||||||
|         enable_fusion=enable_fusion, |  | ||||||
|     ) |  | ||||||
|     model.load_ckpt(checkpoint_path) |  | ||||||
|     return model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_config_from_original(clap_model): |  | ||||||
|     audio_config = { |  | ||||||
|         "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim, |  | ||||||
|         "depths": clap_model.model.audio_branch.depths, |  | ||||||
|         "hidden_size": clap_model.model.audio_projection[0].in_features, |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features} |  | ||||||
|  |  | ||||||
|     return ClapConfig(audio_config=audio_config, text_config=text_config) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_state_dict(state_dict): |  | ||||||
|     model_state_dict = {} |  | ||||||
|  |  | ||||||
|     sequential_layers_pattern = r".*sequential.(\d+).*" |  | ||||||
|     text_projection_pattern = r".*_projection.(\d+).*" |  | ||||||
|  |  | ||||||
|     for key, value in state_dict.items(): |  | ||||||
|         # check if any key needs to be modified |  | ||||||
|         for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): |  | ||||||
|             if key_to_modify in key: |  | ||||||
|                 key = key.replace(key_to_modify, new_key) |  | ||||||
|  |  | ||||||
|         if re.match(sequential_layers_pattern, key): |  | ||||||
|             # replace sequential layers with list |  | ||||||
|             sequential_layer = re.match(sequential_layers_pattern, key).group(1) |  | ||||||
|  |  | ||||||
|             key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.") |  | ||||||
|         elif re.match(text_projection_pattern, key): |  | ||||||
|             projection_layer = int(re.match(text_projection_pattern, key).group(1)) |  | 
|  |  | ||||||
|             # Because in CLAP they use `nn.Sequential`... |  | ||||||
|             transformers_projection_layer = 1 if projection_layer == 0 else 2 |  | 
|  |  | ||||||
|             key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.") |  | ||||||
|  |  | ||||||
|         if "audio" and "qkv" in key: |  | ||||||
|             # split qkv into query key and value |  | ||||||
|             mixed_qkv = value |  | ||||||
|             qkv_dim = mixed_qkv.size(0) // 3 |  | ||||||
|  |  | ||||||
|             query_layer = mixed_qkv[:qkv_dim] |  | ||||||
|             key_layer = mixed_qkv[qkv_dim : qkv_dim * 2] |  | ||||||
|             value_layer = mixed_qkv[qkv_dim * 2 :] |  | ||||||
|  |  | ||||||
|             model_state_dict[key.replace("qkv", "query")] = query_layer |  | ||||||
|             model_state_dict[key.replace("qkv", "key")] = key_layer |  | ||||||
|             model_state_dict[key.replace("qkv", "value")] = value_layer |  | ||||||
|         else: |  | ||||||
|             model_state_dict[key] = value |  | ||||||
|  |  | ||||||
|     return model_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False): |  | ||||||
|     clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion) |  | ||||||
|  |  | ||||||
|     clap_model.eval() |  | ||||||
|     state_dict = clap_model.model.state_dict() |  | ||||||
|     state_dict = rename_state_dict(state_dict) |  | ||||||
|  |  | ||||||
|     transformers_config = get_config_from_original(clap_model) |  | ||||||
|     transformers_config.audio_config.enable_fusion = enable_fusion |  | ||||||
|     model = ClapModel(transformers_config) |  | ||||||
|  |  | ||||||
|     # ignore the spectrogram embedding layer |  | ||||||
|     model.load_state_dict(state_dict, strict=False) |  | ||||||
|  |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     transformers_config.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") |  | ||||||
|     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") |  | ||||||
|     parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not") |  | ||||||
|     parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_clap_checkpoint( |  | ||||||
|         args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion |  | ||||||
|     ) |  | ||||||
| @ -1,156 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from clip import load |  | ||||||
|  |  | ||||||
| from transformers import CLIPConfig, CLIPModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_attn_layer(hf_attn_layer, pt_attn_layer): |  | ||||||
|     q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0) |  | ||||||
|     q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0) |  | ||||||
|  |  | ||||||
|     out_proj_weights = pt_attn_layer.out_proj.weight |  | ||||||
|     out_proj_bias = pt_attn_layer.out_proj.bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.q_proj.weight.data = q_proj |  | ||||||
|     hf_attn_layer.q_proj.bias.data = q_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.k_proj.weight.data = k_proj |  | ||||||
|     hf_attn_layer.k_proj.bias.data = k_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.v_proj.weight.data = v_proj |  | ||||||
|     hf_attn_layer.v_proj.bias.data = v_proj_bias |  | ||||||
|  |  | ||||||
|     hf_attn_layer.out_proj.weight = out_proj_weights |  | ||||||
|     hf_attn_layer.out_proj.bias = out_proj_bias |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_mlp(hf_mlp, pt_mlp): |  | ||||||
|     copy_linear(hf_mlp.fc1, pt_mlp.c_fc) |  | ||||||
|     copy_linear(hf_mlp.fc2, pt_mlp.c_proj) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_linear(hf_linear, pt_linear): |  | ||||||
|     hf_linear.weight = pt_linear.weight |  | ||||||
|     hf_linear.bias = pt_linear.bias |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_layer(hf_layer, pt_layer): |  | ||||||
|     # copy layer norms |  | ||||||
|     copy_linear(hf_layer.layer_norm1, pt_layer.ln_1) |  | ||||||
|     copy_linear(hf_layer.layer_norm2, pt_layer.ln_2) |  | ||||||
|  |  | ||||||
|     # copy MLP |  | ||||||
|     copy_mlp(hf_layer.mlp, pt_layer.mlp) |  | ||||||
|  |  | ||||||
|     # copy attn |  | ||||||
|     copy_attn_layer(hf_layer.self_attn, pt_layer.attn) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_layers(hf_layers, pt_layers): |  | ||||||
|     for hf_layer, pt_layer in zip(hf_layers, pt_layers): |  | ||||||
|         copy_layer(hf_layer, pt_layer) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_encoder(hf_encoder, pt_model): |  | ||||||
|     # copy embeds |  | 
|     hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight |  | ||||||
|     hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding |  | ||||||
|  |  | ||||||
|     # copy layer norm |  | ||||||
|     copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final) |  | ||||||
|  |  | ||||||
|     # copy hidden layers |  | ||||||
|     copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_text_model_and_projection(hf_model, pt_model): |  | ||||||
|     # copy projection |  | ||||||
|     hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous() |  | ||||||
|  |  | ||||||
|     # copy text encoder |  | ||||||
|     copy_encoder(hf_model.text_model, pt_model) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def copy_vision_model_and_projection(hf_model, pt_model): |  | 
|     # copy projection |  | ||||||
|     hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous() |  | ||||||
|  |  | ||||||
|     # copy layer norms |  | ||||||
|     copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre) |  | ||||||
|     copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post) |  | ||||||
|  |  | ||||||
|     # copy embeds |  | ||||||
|     hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data |  | ||||||
|     hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding |  | ||||||
|     hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data |  | ||||||
|  |  | ||||||
|     # copy encoder |  | ||||||
|     copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to transformers design. |  | ||||||
|     """ |  | ||||||
|     if config_path is not None: |  | ||||||
|         config = CLIPConfig.from_pretrained(config_path) |  | ||||||
|     else: |  | ||||||
|         config = CLIPConfig(projection_dim=512, text_config={}, vision_config={}) |  | ||||||
|  |  | ||||||
|     hf_model = CLIPModel(config).eval() |  | ||||||
|  |  | ||||||
|     pt_model, _ = load(checkpoint_path, device="cpu", jit=False) |  | ||||||
|     pt_model = pt_model.eval() |  | ||||||
|  |  | ||||||
|     copy_text_model_and_projection(hf_model, pt_model) |  | ||||||
|     copy_vision_model_and_projection(hf_model, pt_model) |  | 
|     hf_model.logit_scale = pt_model.logit_scale |  | ||||||
|  |  | ||||||
|     # Use `eos_token` so the example is more meaningful |  | ||||||
|     input_ids = torch.tensor( |  | ||||||
|         [ |  | ||||||
|             [config.text_config.bos_token_id] |  | ||||||
|             + list(range(3, 77)) |  | ||||||
|             + [config.text_config.eos_token_id] |  | ||||||
|             + [config.text_config.pad_token_id] |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|     pixel_values = torch.randn(1, 3, 224, 224) |  | ||||||
|  |  | ||||||
|     hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True) |  | ||||||
|     hf_logits_per_image = hf_outputs.logits_per_image |  | ||||||
|     hf_logits_per_text = hf_outputs.logits_per_text |  | ||||||
|     pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids) |  | ||||||
|  |  | ||||||
|     assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3) |  | ||||||
|     assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3) |  | ||||||
|  |  | ||||||
|     hf_model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint") |  | ||||||
|     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path) |  | ||||||
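After the script has written the output folder, the converted weights can be smoke-tested with the same dummy batch the logit check above uses; `./clip_hf` stands in for `--pytorch_dump_folder_path`:

```py
import torch
from transformers import CLIPModel

model = CLIPModel.from_pretrained("./clip_hf").eval()  # "./clip_hf" is the hypothetical dump folder
text_config = model.config.text_config

# same dummy batch as the conversion check: bos + 74 arbitrary ids + eos + pad (77 tokens total)
input_ids = torch.tensor(
    [[text_config.bos_token_id] + list(range(3, 77)) + [text_config.eos_token_id] + [text_config.pad_token_id]]
)
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(input_ids=input_ids, pixel_values=pixel_values)
print(outputs.logits_per_image, outputs.logits_per_text)
```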
| @ -1,264 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     CLIPSegConfig, |  | ||||||
|     CLIPSegForImageSegmentation, |  | ||||||
|     CLIPSegProcessor, |  | ||||||
|     CLIPSegTextConfig, |  | ||||||
|     CLIPSegVisionConfig, |  | ||||||
|     CLIPTokenizer, |  | ||||||
|     ViTImageProcessor, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_clipseg_config(model_name): |  | ||||||
|     text_config = CLIPSegTextConfig() |  | ||||||
|     vision_config = CLIPSegVisionConfig(patch_size=16) |  | ||||||
|  |  | ||||||
|     use_complex_transposed_convolution = "refined" in model_name |  | 
|     reduce_dim = 16 if "rd16" in model_name else 64 |  | ||||||
|  |  | ||||||
|     config = CLIPSegConfig.from_text_vision_configs( |  | ||||||
|         text_config, |  | ||||||
|         vision_config, |  | ||||||
|         use_complex_transposed_convolution=use_complex_transposed_convolution, |  | ||||||
|         reduce_dim=reduce_dim, |  | ||||||
|     ) |  | ||||||
|     return config |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     # update prefixes |  | ||||||
|     if "clip_model" in name: |  | ||||||
|         name = name.replace("clip_model", "clip") |  | ||||||
|     if "transformer" in name: |  | ||||||
|         if "visual" in name: |  | ||||||
|             name = name.replace("visual.transformer", "vision_model") |  | ||||||
|         else: |  | ||||||
|             name = name.replace("transformer", "text_model") |  | ||||||
|     if "resblocks" in name: |  | ||||||
|         name = name.replace("resblocks", "encoder.layers") |  | ||||||
|     if "ln_1" in name: |  | ||||||
|         name = name.replace("ln_1", "layer_norm1") |  | ||||||
|     if "ln_2" in name: |  | ||||||
|         name = name.replace("ln_2", "layer_norm2") |  | ||||||
|     if "c_fc" in name: |  | ||||||
|         name = name.replace("c_fc", "fc1") |  | ||||||
|     if "c_proj" in name: |  | ||||||
|         name = name.replace("c_proj", "fc2") |  | ||||||
|     if "attn" in name and "self" not in name: |  | ||||||
|         name = name.replace("attn", "self_attn") |  | ||||||
|     # text encoder |  | ||||||
|     if "token_embedding" in name: |  | ||||||
|         name = name.replace("token_embedding", "text_model.embeddings.token_embedding") |  | ||||||
|     if "positional_embedding" in name and "visual" not in name: |  | ||||||
|         name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight") |  | ||||||
|     if "ln_final" in name: |  | ||||||
|         name = name.replace("ln_final", "text_model.final_layer_norm") |  | ||||||
|     # vision encoder |  | ||||||
|     if "visual.class_embedding" in name: |  | ||||||
|         name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding") |  | ||||||
|     if "visual.conv1" in name: |  | ||||||
|         name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding") |  | ||||||
|     if "visual.positional_embedding" in name: |  | ||||||
|         name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight") |  | ||||||
|     if "visual.ln_pre" in name: |  | ||||||
|         name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm") |  | ||||||
|     if "visual.ln_post" in name: |  | ||||||
|         name = name.replace("visual.ln_post", "vision_model.post_layernorm") |  | ||||||
|     # projection layers |  | ||||||
|     if "visual.proj" in name: |  | ||||||
|         name = name.replace("visual.proj", "visual_projection.weight") |  | ||||||
|     if "text_projection" in name: |  | ||||||
|         name = name.replace("text_projection", "text_projection.weight") |  | ||||||
|     # decoder |  | ||||||
|     if "trans_conv" in name: |  | ||||||
|         name = name.replace("trans_conv", "transposed_convolution") |  | ||||||
|     if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name: |  | ||||||
|         name = "decoder." + name |  | ||||||
|     if "blocks" in name: |  | ||||||
|         name = name.replace("blocks", "decoder.layers") |  | ||||||
|     if "linear1" in name: |  | ||||||
|         name = name.replace("linear1", "mlp.fc1") |  | ||||||
|     if "linear2" in name: |  | ||||||
|         name = name.replace("linear2", "mlp.fc2") |  | ||||||
|     if "norm1" in name and "layer_" not in name: |  | ||||||
|         name = name.replace("norm1", "layer_norm1") |  | ||||||
|     if "norm2" in name and "layer_" not in name: |  | ||||||
|         name = name.replace("norm2", "layer_norm2") |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_state_dict(orig_state_dict, config): |  | ||||||
|     for key in orig_state_dict.copy().keys(): |  | ||||||
|         val = orig_state_dict.pop(key) |  | ||||||
|  |  | ||||||
|         if key.startswith("clip_model") and "attn.in_proj" in key: |  | ||||||
|             key_split = key.split(".") |  | ||||||
|             if "visual" in key: |  | ||||||
|                 layer_num = int(key_split[4]) |  | ||||||
|                 dim = config.vision_config.hidden_size |  | ||||||
|                 prefix = "vision_model" |  | ||||||
|             else: |  | ||||||
|                 layer_num = int(key_split[3]) |  | ||||||
|                 dim = config.text_config.hidden_size |  | ||||||
|                 prefix = "text_model" |  | ||||||
|  |  | ||||||
|             if "weight" in key: |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :] |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[ |  | ||||||
|                     dim : dim * 2, : |  | ||||||
|                 ] |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :] |  | ||||||
|             else: |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim] |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2] |  | ||||||
|                 orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:] |  | ||||||
|         elif "self_attn" in key and "out_proj" not in key: |  | ||||||
|             key_split = key.split(".") |  | ||||||
|             layer_num = int(key_split[1]) |  | ||||||
|             dim = config.reduce_dim |  | ||||||
|             if "weight" in key: |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :] |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :] |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :] |  | ||||||
|             else: |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim] |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2] |  | ||||||
|                 orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:] |  | ||||||
|         else: |  | ||||||
|             new_name = rename_key(key) |  | ||||||
|             if "visual_projection" in new_name or "text_projection" in new_name: |  | ||||||
|                 val = val.T |  | ||||||
|             orig_state_dict[new_name] = val |  | ||||||
|  |  | ||||||
|     return orig_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     image = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|     return image |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub): |  | ||||||
|     config = get_clipseg_config(model_name) |  | ||||||
|     model = CLIPSegForImageSegmentation(config) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     state_dict = torch.load(checkpoint_path, map_location="cpu") |  | ||||||
|  |  | ||||||
|     # remove some keys |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         if key.startswith("model"): |  | ||||||
|             state_dict.pop(key, None) |  | ||||||
|  |  | ||||||
|     # rename some keys |  | ||||||
|     state_dict = convert_state_dict(state_dict, config) |  | ||||||
|     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) |  | ||||||
|  |  | ||||||
|     if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]: |  | ||||||
|         raise ValueError("Missing keys that are not expected: {}".format(missing_keys)) |  | ||||||
|     if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]: |  | ||||||
|         raise ValueError(f"Unexpected keys: {unexpected_keys}") |  | ||||||
|  |  | ||||||
|     image_processor = ViTImageProcessor(size=352) |  | ||||||
|     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") |  | ||||||
|     processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer) |  | ||||||
|  |  | ||||||
|     image = prepare_img() |  | ||||||
|     text = ["a glass", "something to fill", "wood", "a jar"] |  | ||||||
|  |  | ||||||
|     inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt") |  | ||||||
|  |  | ||||||
|     with torch.no_grad(): |  | ||||||
|         outputs = model(**inputs) |  | ||||||
|  |  | ||||||
|     # verify values |  | ||||||
|     expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645]) |  | ||||||
|     expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328]) |  | ||||||
|     if model_name == "clipseg-rd64-refined": |  | ||||||
|         expected_masks_slice = torch.tensor( |  | ||||||
|             [[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]] |  | ||||||
|         ) |  | ||||||
|     elif model_name == "clipseg-rd64": |  | ||||||
|         expected_masks_slice = torch.tensor( |  | ||||||
|             [[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]] |  | ||||||
|         ) |  | ||||||
|     elif model_name == "clipseg-rd16": |  | ||||||
|         expected_masks_slice = torch.tensor( |  | ||||||
|             [[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]] |  | ||||||
|         ) |  | ||||||
|     else: |  | ||||||
|         raise ValueError(f"Model name {model_name} not supported.") |  | ||||||
|  |  | ||||||
|     assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3) |  | ||||||
|     assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3) |  | ||||||
|     assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3) |  | ||||||
|     print("Looks ok!") |  | ||||||
|  |  | ||||||
|     if pytorch_dump_folder_path is not None: |  | ||||||
|         print(f"Saving model and processor to {pytorch_dump_folder_path}") |  | ||||||
|         model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         print(f"Pushing model and processor for {model_name} to the hub") |  | ||||||
|         model.push_to_hub(f"CIDAS/{model_name}") |  | ||||||
|         processor.push_to_hub(f"CIDAS/{model_name}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="clipseg-rd64", |  | ||||||
|         type=str, |  | ||||||
|         choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"], |  | ||||||
|         help=( |  | ||||||
|             "Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning" |  | ||||||
|             " reduce dimension)" |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_path", |  | ||||||
|         default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth", |  | ||||||
|         type=str, |  | ||||||
|         help=( |  | ||||||
|             "Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and" |  | ||||||
|             " the decoder weights." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub) |  | ||||||
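As a follow-up, a minimal sketch of running the converted CLIPSeg model and processor on the same cat image the script verifies against; `./clipseg_hf` stands in for `--pytorch_dump_folder_path` and the prompts are arbitrary:

```py
import requests
import torch
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

model = CLIPSegForImageSegmentation.from_pretrained("./clipseg_hf").eval()  # hypothetical dump folder
processor = CLIPSegProcessor.from_pretrained("./clipseg_hf")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompts = ["a cat", "a remote control"]
inputs = processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # one low-resolution mask logit map per prompt
```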
| @ -1,234 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| """ |  | ||||||
| Weights conversion script for CLVP |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
|  |  | ||||||
| from transformers import ClvpConfig, ClvpModelForConditionalGeneration |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _MODELS = { |  | ||||||
|     "clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth", |  | ||||||
|     "decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth", |  | ||||||
| } |  | ||||||
|  |  | ||||||
| dim = 1024 |  | ||||||
| sub_dim = dim // 16 |  | ||||||
|  |  | ||||||
| CLVP_ENCODERS_MAPPING = { |  | ||||||
|     "text_transformer.transformer.attn_layers": "text_encoder_model", |  | ||||||
|     "speech_transformer.transformer.attn_layers": "speech_encoder_model", |  | ||||||
|     "text_transformer.transformer.norm": "text_encoder_model.final_layer_norm", |  | ||||||
|     "speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm", |  | ||||||
|     "to_text_latent": "text_encoder_model.projection", |  | ||||||
|     "to_speech_latent": "speech_encoder_model.projection", |  | ||||||
|     "text_emb": "text_encoder_model.token_embedding", |  | ||||||
|     "speech_emb": "speech_encoder_model.token_embedding", |  | ||||||
|     "1.wrap.net.0": "mlp.fc1", |  | ||||||
|     "1.wrap.net.3": "mlp.fc2", |  | ||||||
|     "1.wrap": "self_attn", |  | ||||||
|     "to_out": "out_proj", |  | ||||||
|     "to_q": "q_proj", |  | ||||||
|     "to_k": "k_proj", |  | ||||||
|     "to_v": "v_proj", |  | ||||||
|     "temperature": "logit_scale", |  | ||||||
| } |  | ||||||
|  |  | ||||||
| CLVP_DECODER_MAPPING = { |  | ||||||
|     "conditioning_encoder.init": "conditioning_encoder.mel_conv", |  | ||||||
|     "conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks", |  | ||||||
|     "mel_attn_blocks": "group_norms", |  | ||||||
|     ".norm.weight": ".weight", |  | ||||||
|     ".norm.bias": ".bias", |  | ||||||
|     "text_embedding": "conditioning_encoder.text_token_embedding", |  | ||||||
|     "text_pos_embedding.emb": "conditioning_encoder.text_position_embedding", |  | ||||||
|     "final_norm": "speech_decoder_model.final_norm", |  | ||||||
|     "mel_head": "speech_decoder_model.lm_head", |  | ||||||
|     "gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm", |  | ||||||
|     "mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer", |  | ||||||
|     "mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer", |  | ||||||
|     "gpt.h": "speech_decoder_model.model.decoder.layers", |  | ||||||
|     "ln_1": "input_layernorm", |  | ||||||
|     "ln_2": "post_attention_layernorm", |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def update_index(present_index): |  | ||||||
|     if present_index % 2 == 0: |  | ||||||
|         return int(present_index / 2) |  | ||||||
|     else: |  | ||||||
|         return int((present_index - 1) / 2) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_encoder_weights(original_weights): |  | ||||||
|     converted_weights = {} |  | ||||||
|     original_weights_keys = sorted(original_weights.keys()) |  | ||||||
|     for original_key in original_weights_keys: |  | ||||||
|         updated_key = original_key |  | ||||||
|         # for input_rmsnorm.weight and post_attention_rmsnorm.weight |  | ||||||
|         if "0.0.g" in updated_key: |  | ||||||
|             present_index = updated_key.split(".")[4] |  | ||||||
|             if int(present_index) % 2 == 0: |  | ||||||
|                 updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight") |  | ||||||
|             else: |  | ||||||
|                 updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight") |  | ||||||
|  |  | ||||||
|         if "transformer.attn_layers.layers" in updated_key: |  | ||||||
|             present_index = updated_key.split(".")[4] |  | ||||||
|             updated_index = update_index(int(present_index)) |  | ||||||
|             updated_key = updated_key.replace( |  | ||||||
|                 f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}" |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         for k, v in CLVP_ENCODERS_MAPPING.items(): |  | ||||||
|             if k in updated_key: |  | ||||||
|                 updated_key = updated_key.replace(k, v) |  | ||||||
|  |  | ||||||
|         converted_weights[updated_key] = original_weights.pop(original_key) |  | ||||||
|  |  | ||||||
|     return converted_weights |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_decoder_weights(original_weights): |  | ||||||
|     converted_weights = {} |  | ||||||
|     original_weights_keys = sorted(original_weights.keys()) |  | ||||||
|     for original_key in original_weights_keys: |  | ||||||
|         updated_key = original_key |  | ||||||
|         if len(updated_key.split(".")) > 3: |  | ||||||
|             index, attr = updated_key.split(".")[2], updated_key.split(".")[-1] |  | ||||||
|  |  | ||||||
|         # for decoder attention |  | ||||||
|         if "attn.c_attn" in updated_key: |  | ||||||
|             if attr == "weight": |  | ||||||
|                 slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0) |  | ||||||
|             else: |  | ||||||
|                 slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0) |  | ||||||
|             converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1 |  | ||||||
|             converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2 |  | ||||||
|             converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3 |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         if "attn.c_proj" in updated_key: |  | ||||||
|             converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = ( |  | ||||||
|                 original_weights[updated_key].squeeze(-1).T |  | ||||||
|             ) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key: |  | ||||||
|             original_weights.pop(updated_key) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         # conditioning encoder attention |  | 
|         if "qkv" in updated_key: |  | ||||||
|             if attr == "weight": |  | ||||||
|                 slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0) |  | ||||||
|             else: |  | ||||||
|                 slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0) |  | ||||||
|  |  | ||||||
|             indices = torch.arange(dim) |  | ||||||
|             index1, index2, index3 = ( |  | ||||||
|                 indices.unfold(0, sub_dim, sub_dim * 3).flatten(), |  | ||||||
|                 indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(), |  | ||||||
|                 indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(), |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate( |  | ||||||
|                 [slice1[index1], slice2[index3], slice3[index2]], |  | ||||||
|                 axis=0, |  | ||||||
|             ) |  | ||||||
|             converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate( |  | ||||||
|                 [slice1[index2], slice2[index1], slice3[index3]], |  | ||||||
|                 axis=0, |  | ||||||
|             ) |  | ||||||
|             converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate( |  | ||||||
|                 [slice1[index3], slice2[index2], slice3[index1]], |  | ||||||
|                 axis=0, |  | ||||||
|             ) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         if "proj_out" in updated_key: |  | ||||||
|             converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[ |  | ||||||
|                 updated_key |  | ||||||
|             ].squeeze(-1) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         for k, v in CLVP_DECODER_MAPPING.items(): |  | ||||||
|             if k in updated_key: |  | ||||||
|                 updated_key = updated_key.replace(k, v) |  | ||||||
|  |  | ||||||
|         converted_weights[updated_key] = original_weights.pop(original_key) |  | ||||||
|  |  | ||||||
|     return converted_weights |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _download(url: str, root: str): |  | ||||||
|     repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}" |  | ||||||
|     filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}" |  | ||||||
|     hf_hub_download( |  | ||||||
|         repo_id=repo_id, |  | ||||||
|         filename=filename, |  | ||||||
|         force_filename=root, |  | ||||||
|         local_dir_use_symlinks=False, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path): |  | ||||||
|     converted_checkpoint = {} |  | ||||||
|  |  | ||||||
|     for each_model_name, each_model_url in _MODELS.items(): |  | ||||||
|         each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1]) |  | ||||||
|         if not os.path.exists(each_model_path): |  | ||||||
|             print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}") |  | ||||||
|             _download(url=each_model_url, root=each_model_path) |  | ||||||
|  |  | ||||||
|         if each_model_name == "clvp": |  | ||||||
|             clvp_checkpoint = torch.load(each_model_path, map_location="cpu") |  | ||||||
|         else: |  | ||||||
|             decoder_checkpoint = torch.load(each_model_path, map_location="cpu") |  | ||||||
|  |  | ||||||
|     # Converting the weights |  | ||||||
|     converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint)) |  | ||||||
|     converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint)) |  | ||||||
|  |  | ||||||
|     config = ClvpConfig.from_pretrained("susnato/clvp_dev") |  | ||||||
|     model = ClvpModelForConditionalGeneration(config) |  | ||||||
|  |  | ||||||
|     model.load_state_dict(converted_checkpoint, strict=True) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     print(f"Model saved at {pytorch_dump_folder_path}!") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | 
|     parser.add_argument( |  | ||||||
|         "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)" |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the output PyTorch model. (Please enter full path)", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path) |  | ||||||
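A minimal sketch of reloading the converted CLVP checkpoint, assuming `--pytorch_dump_folder_path ./clvp_hf` was used above; the path is a placeholder, and actual speech generation additionally needs a `ClvpProcessor` plus text and audio inputs:

```py
from transformers import ClvpModelForConditionalGeneration

# "./clvp_hf" is a placeholder for the --pytorch_dump_folder_path used above
model = ClvpModelForConditionalGeneration.from_pretrained("./clvp_hf").eval()
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```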
| @ -1,214 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2024 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """ |  | ||||||
| Convert ColPali weights from the original repository to the HF model format. |  | ||||||
|  |  | ||||||
| Original repository: https://github.com/illuin-tech/colpali. |  | ||||||
|  |  | ||||||
| NOTE: This script was originally run using `torch==2.5.1` and with: |  | ||||||
|  |  | ||||||
| ```bash |  | ||||||
| python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ |  | ||||||
|     --model_id vidore/colpali-v1.2-merged \ |  | ||||||
|     --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \ |  | ||||||
|     --original_vlm_name_or_path google/paligemma-3b-mix-448 \ |  | ||||||
|     --output_dir vidore/colpali-v1.2-hf-internal \ |  | ||||||
|     --push_to_hub |  | ||||||
|  |  | ||||||
| python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ |  | ||||||
|     --model_id vidore/colpali-v1.3-merged \ |  | ||||||
|     --revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \ |  | ||||||
|     --original_vlm_name_or_path google/paligemma-3b-mix-448 \ |  | ||||||
|     --output_dir vidore/colpali-v1.3-hf \ |  | ||||||
|     --push_to_hub |  | ||||||
| ``` |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import glob |  | ||||||
| from pathlib import Path |  | ||||||
| from typing import Any, Dict, Optional |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import snapshot_download |  | ||||||
| from safetensors import safe_open |  | ||||||
|  |  | ||||||
| from transformers import AutoConfig |  | ||||||
| from transformers.models.colpali import ColPaliForRetrieval |  | ||||||
| from transformers.models.colpali.configuration_colpali import ColPaliConfig |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ORIGINAL_DTYPE = torch.bfloat16 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]: |  | ||||||
|     new_state_dict = {} |  | ||||||
|     for key, value in state_dict.items(): |  | ||||||
|         new_key = key |  | ||||||
|         if key.startswith("custom_text_proj"): |  | ||||||
|             new_key = key.replace("custom_text_proj", "embedding_proj_layer") |  | ||||||
|         if key.startswith("model."): |  | ||||||
|             new_key = key.replace("model.", "vlm.", 1) |  | ||||||
|         new_state_dict[new_key] = value |  | ||||||
|     return new_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]: |  | ||||||
|     directory_path = snapshot_download( |  | ||||||
|         repo_id=model_id, |  | ||||||
|         revision=revision, |  | ||||||
|         allow_patterns=["*.safetensors"], |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     original_state_dict = {} |  | ||||||
|     for path in glob.glob(f"{directory_path}/*"): |  | ||||||
|         if path.endswith(".safetensors"): |  | ||||||
|             with safe_open(path, framework="pt", device="cpu") as f: |  | ||||||
|                 for key in f.keys(): |  | ||||||
|                     original_state_dict[key] = f.get_tensor(key) |  | ||||||
|  |  | ||||||
|     # Some weights are tied, so `lm_head` is not saved. Clone it so the state dict can be loaded. |  | 
|     if "lm_head.weight" not in original_state_dict: |  | ||||||
|         original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[ |  | ||||||
|             "model.language_model.model.embed_tokens.weight" |  | ||||||
|         ].clone() |  | ||||||
|  |  | ||||||
|     return original_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_colpali_weights_to_hf( |  | ||||||
|     model_id: str, |  | ||||||
|     output_dir: str, |  | ||||||
|     push_to_hub: bool, |  | ||||||
|     revision: Optional[str] = None, |  | ||||||
|     original_vlm_name_or_path: Optional[str] = None, |  | ||||||
| ): |  | ||||||
|     # Load the original model data |  | ||||||
|     original_config = AutoConfig.from_pretrained( |  | ||||||
|         model_id, |  | ||||||
|         revision=revision, |  | ||||||
|     ) |  | ||||||
|     if original_vlm_name_or_path is not None: |  | ||||||
|         original_config._name_or_path = original_vlm_name_or_path |  | ||||||
|     if hasattr(original_config, "architectures"): |  | ||||||
|         delattr(original_config, "architectures") |  | ||||||
|  |  | ||||||
|     original_state_dict = load_original_state_dict(model_id, revision=revision) |  | ||||||
|  |  | ||||||
|     # Format the state_dict keys |  | ||||||
|     original_state_dict = rename_state_dict_keys(original_state_dict) |  | ||||||
|  |  | ||||||
|     # Create the new config |  | ||||||
|     config = ColPaliConfig( |  | ||||||
|         vlm_config=original_config, |  | ||||||
|         embedding_dim=128,  # hardcoded in the original model |  | ||||||
|     ) |  | ||||||
|     config.model_type = "colpali" |  | ||||||
|     config.is_composition = False |  | ||||||
|  |  | ||||||
|     # Load the untrained model |  | ||||||
|     model = ColPaliForRetrieval(config=config).to("cpu").eval() |  | ||||||
|     print("Created model with new config and randomly initialized weights") |  | ||||||
|  |  | ||||||
|     # NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision. |  | ||||||
|     # There are two ways to set the model's dtype: |  | ||||||
|     # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision. |  | ||||||
|     # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision. |  | ||||||
|     # The following snippet allows a fine-grained control over the model's dtype, making sure that all |  | ||||||
|     # the new weights' dtypes match the original model. |  | ||||||
|     for param in model.parameters(): |  | ||||||
|         param.data = param.data.to(ORIGINAL_DTYPE) |  | ||||||
|     print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`") |  | ||||||
|  |  | ||||||
|     # Load the original weights |  | ||||||
|     model.load_state_dict(original_state_dict) |  | ||||||
|     print("Loaded original model weights") |  | ||||||
|  |  | ||||||
|     # Tie the weights (following ColPali's `__init__` step) |  | 
|     if model.vlm.language_model._tied_weights_keys is not None: |  | ||||||
|         model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys] |  | ||||||
|  |  | ||||||
|     # Sanity check: ensure all keys are the same |  | ||||||
|     state_dict_keys_old = set(original_state_dict.keys()) |  | ||||||
|     state_dict_keys_new = set(model.state_dict().keys()) |  | ||||||
|     disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new) |  | ||||||
|     if disjoint_keys: |  | ||||||
|         raise ValueError(f"Incompatible keys: {disjoint_keys}") |  | ||||||
|  |  | ||||||
|     # Save the model |  | ||||||
|     if push_to_hub: |  | ||||||
|         model.push_to_hub(output_dir, private=True) |  | ||||||
|         print(f"Model pushed to the hub at `{output_dir}`") |  | ||||||
|     else: |  | ||||||
|         Path(output_dir).mkdir(exist_ok=True, parents=True) |  | ||||||
|         model.save_pretrained(output_dir) |  | ||||||
|         print(f"Model saved to `{output_dir}`") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser( |  | ||||||
|         description=""" |  | ||||||
|         This script converts the original ColPali model to the HF model format. |  | ||||||
|  |  | ||||||
|         Example usage: |  | ||||||
|         ```bash |  | ||||||
|         python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \ |  | ||||||
|             --model_id vidore/colpali-v1.2-merged \ |  | ||||||
|             --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \ |  | ||||||
|             --original_vlm_name_or_path google/paligemma-3b-mix-448 \ |  | ||||||
|             --output_dir vidore/colpali-v1.2-hf \ |  | ||||||
|             --push_to_hub |  | ||||||
|         ``` |  | ||||||
|         """ |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_id", |  | ||||||
|         help="Model ID of the original model to convert", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--output_dir", |  | ||||||
|         help="Location to write HF model and tokenizer", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", |  | ||||||
|         help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally", |  | ||||||
|         action="store_true", |  | ||||||
|         default=False, |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--revision", |  | ||||||
|         help="Revision of the model to download", |  | ||||||
|         default=None, |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--original_vlm_name_or_path", |  | ||||||
|         help="Name or path of the original VLM backbone model", |  | ||||||
|         default=None, |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_colpali_weights_to_hf( |  | ||||||
|         model_id=args.model_id, |  | ||||||
|         output_dir=args.output_dir, |  | ||||||
|         push_to_hub=args.push_to_hub, |  | ||||||
|         revision=args.revision, |  | ||||||
|         original_vlm_name_or_path=args.original_vlm_name_or_path, |  | ||||||
|     ) |  | ||||||
| @ -1,324 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert Conditional DETR checkpoints.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from collections import OrderedDict |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     ConditionalDetrConfig, |  | ||||||
|     ConditionalDetrForObjectDetection, |  | ||||||
|     ConditionalDetrForSegmentation, |  | ||||||
|     ConditionalDetrImageProcessor, |  | ||||||
| ) |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| # here we list all keys to be renamed (original name on the left, our name on the right) |  | ||||||
| rename_keys = [] |  | ||||||
| for i in range(6): |  | ||||||
|     # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight")) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias")) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight")) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias")) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight")) |  | ||||||
|     rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias")) |  | ||||||
|     # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         ( |  | ||||||
|             f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight", |  | ||||||
|             f"decoder.layers.{i}.encoder_attn.out_proj.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         ( |  | ||||||
|             f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias", |  | ||||||
|             f"decoder.layers.{i}.encoder_attn.out_proj.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias")) |  | ||||||
|  |  | ||||||
|     # q, k, v projections in self/cross-attention in decoder for conditional DETR |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight") |  | ||||||
|     ) |  | ||||||
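|     # ca_qpos_proj is deliberately not renamed per layer: conditional DETR only applies the learned |  | ||||||
|     # query position projection in the first decoder layer's cross-attention, so only the layer-0 |  | ||||||
|     # weight and bias are converted in the rename_keys.extend() further below. |  | ||||||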
|     # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight") |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias") |  | ||||||
|     ) |  | ||||||
|     # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias") |  | ||||||
|     ) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias")) |  | ||||||
|     rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias")) |  | ||||||
|     rename_keys.append( |  | ||||||
|         (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias") |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
| # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads |  | ||||||
| # for conditional DETR, also convert reference point head and query scale MLP |  | ||||||
| rename_keys.extend( |  | ||||||
|     [ |  | ||||||
|         ("input_proj.weight", "input_projection.weight"), |  | ||||||
|         ("input_proj.bias", "input_projection.bias"), |  | ||||||
|         ("query_embed.weight", "query_position_embeddings.weight"), |  | ||||||
|         ("transformer.decoder.norm.weight", "decoder.layernorm.weight"), |  | ||||||
|         ("transformer.decoder.norm.bias", "decoder.layernorm.bias"), |  | ||||||
|         ("class_embed.weight", "class_labels_classifier.weight"), |  | ||||||
|         ("class_embed.bias", "class_labels_classifier.bias"), |  | ||||||
|         ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"), |  | ||||||
|         ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"), |  | ||||||
|         ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"), |  | ||||||
|         ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"), |  | ||||||
|         ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"), |  | ||||||
|         ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"), |  | ||||||
|         ("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"), |  | ||||||
|         ("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"), |  | ||||||
|         ("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"), |  | ||||||
|         ("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"), |  | ||||||
|         ("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"), |  | ||||||
|         ("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"), |  | ||||||
|         ("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"), |  | ||||||
|         ("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"), |  | ||||||
|         ("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"), |  | ||||||
|         ("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"), |  | ||||||
|     ] |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(state_dict, old, new): |  | ||||||
|     val = state_dict.pop(old) |  | ||||||
|     state_dict[new] = val |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_backbone_keys(state_dict): |  | ||||||
|     new_state_dict = OrderedDict() |  | ||||||
|     for key, value in state_dict.items(): |  | ||||||
|         if "backbone.0.body" in key: |  | ||||||
|             new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model") |  | ||||||
|             new_state_dict[new_key] = value |  | ||||||
|         else: |  | ||||||
|             new_state_dict[key] = value |  | ||||||
|  |  | ||||||
|     return new_state_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_in_q_k_v(state_dict, is_panoptic=False): |  | ||||||
|     prefix = "" |  | ||||||
|     if is_panoptic: |  | ||||||
|         prefix = "conditional_detr." |  | ||||||
|  |  | ||||||
|     # first: transformer encoder |  | ||||||
|     for i in range(6): |  | ||||||
|         # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) |  | ||||||
|         in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight") |  | ||||||
|         in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias") |  | ||||||
|         # next, add query, keys and values (in that order) to the state dict |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :] |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256] |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :] |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512] |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :] |  | ||||||
|         state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     im = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|  |  | ||||||
|     return im |  | ||||||
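| # any RGB image would do here, since the converted model is checked against the original on the same |  | ||||||
| # input below; this particular COCO val2017 picture (two cats on a couch) is simply the one these |  | ||||||
| # conversion scripts conventionally use. |  | ||||||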
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # load default config |  | ||||||
|     config = ConditionalDetrConfig() |  | ||||||
|     # set backbone and dilation attributes |  | ||||||
|     if "resnet101" in model_name: |  | ||||||
|         config.backbone = "resnet101" |  | ||||||
|     if "dc5" in model_name: |  | ||||||
|         config.dilation = True |  | ||||||
|     is_panoptic = "panoptic" in model_name |  | ||||||
|     if is_panoptic: |  | ||||||
|         config.num_labels = 250 |  | ||||||
|     else: |  | ||||||
|         config.num_labels = 91 |  | ||||||
|         repo_id = "huggingface/label-files" |  | ||||||
|         filename = "coco-detection-id2label.json" |  | ||||||
|         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|         id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|         config.id2label = id2label |  | ||||||
|         config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     # load image processor |  | ||||||
|     format = "coco_panoptic" if is_panoptic else "coco_detection" |  | ||||||
|     image_processor = ConditionalDetrImageProcessor(format=format) |  | ||||||
|  |  | ||||||
|     # prepare image |  | ||||||
|     img = prepare_img() |  | ||||||
|     encoding = image_processor(images=img, return_tensors="pt") |  | ||||||
|     pixel_values = encoding["pixel_values"] |  | ||||||
|  |  | ||||||
|     logger.info(f"Converting model {model_name}...") |  | ||||||
|  |  | ||||||
|     # load original model from torch hub |  | ||||||
|     conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval() |  | ||||||
|     state_dict = conditional_detr.state_dict() |  | ||||||
|     # rename keys |  | ||||||
|     for src, dest in rename_keys: |  | ||||||
|         if is_panoptic: |  | ||||||
|             src = "conditional_detr." + src |  | ||||||
|         rename_key(state_dict, src, dest) |  | ||||||
|     state_dict = rename_backbone_keys(state_dict) |  | ||||||
|     # query, key and value matrices need special treatment |  | ||||||
|     read_in_q_k_v(state_dict, is_panoptic=is_panoptic) |  | ||||||
|     # important: we need to prepend a prefix to each of the base model keys, as the head models store |  | ||||||
|     # the base model under a different attribute ("model", or "conditional_detr.model" for the panoptic |  | ||||||
|     # variant); the prediction-head keys (class_labels_classifier, bbox_predictor) keep their own names |  | ||||||
|     prefix = "conditional_detr.model." if is_panoptic else "model." |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         if is_panoptic: |  | ||||||
|             if ( |  | ||||||
|                 key.startswith("conditional_detr") |  | ||||||
|                 and not key.startswith("class_labels_classifier") |  | ||||||
|                 and not key.startswith("bbox_predictor") |  | ||||||
|             ): |  | ||||||
|                 val = state_dict.pop(key) |  | ||||||
|                 state_dict["conditional_detr.model" + key[4:]] = val |  | ||||||
|             elif "class_labels_classifier" in key or "bbox_predictor" in key: |  | ||||||
|                 val = state_dict.pop(key) |  | ||||||
|                 state_dict["conditional_detr." + key] = val |  | ||||||
|             elif key.startswith("bbox_attention") or key.startswith("mask_head"): |  | ||||||
|                 continue |  | ||||||
|             else: |  | ||||||
|                 val = state_dict.pop(key) |  | ||||||
|                 state_dict[prefix + key] = val |  | ||||||
|         else: |  | ||||||
|             if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"): |  | ||||||
|                 val = state_dict.pop(key) |  | ||||||
|                 state_dict[prefix + key] = val |  | ||||||
|     # finally, create HuggingFace model and load state dict |  | ||||||
|     model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config) |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|     model.eval() |  | ||||||
|     model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model") |  | ||||||
|     # verify our conversion |  | ||||||
|     original_outputs = conditional_detr(pixel_values) |  | ||||||
|     outputs = model(pixel_values) |  | ||||||
|     assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4) |  | ||||||
|     assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4) |  | ||||||
|     if is_panoptic: |  | ||||||
|         assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4) |  | ||||||
|  |  | ||||||
|     # Save model and image processor |  | ||||||
|     logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...") |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     image_processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|  |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="conditional_detr_resnet50", |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the CONDITIONAL_DETR model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path) |  | ||||||
| @ -1,57 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2020 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert ConvBERT checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
|  |  | ||||||
| from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
|  |  | ||||||
|  |  | ||||||
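| # The conversion is a two-step round trip: the original TF1 checkpoint is loaded into a PyTorch |  | ||||||
| # ConvBertModel via load_tf_weights_in_convbert and saved to pytorch_dump_path, then reloaded with |  | ||||||
| # from_pt=True so that a TF2 checkpoint is written to the same folder. |  | ||||||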
| def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path): |  | ||||||
|     conf = ConvBertConfig.from_json_file(convbert_config_file) |  | ||||||
|     model = ConvBertModel(conf) |  | ||||||
|  |  | ||||||
|     model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path) |  | ||||||
|     model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|     tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True) |  | ||||||
|     tf_model.save_pretrained(pytorch_dump_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--convbert_config_file", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help=( |  | ||||||
|             "The config json file corresponding to the pre-trained ConvBERT model. \n" |  | ||||||
|             "This specifies the model architecture." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path) |  | ||||||
| @ -1,242 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert ConvNext checkpoints from the original repository. |  | ||||||
|  |  | ||||||
| URL: https://github.com/facebookresearch/ConvNeXt""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_convnext_config(checkpoint_url): |  | ||||||
|     config = ConvNextConfig() |  | ||||||
|  |  | ||||||
|     if "tiny" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 9, 3] |  | ||||||
|         hidden_sizes = [96, 192, 384, 768] |  | ||||||
|     if "small" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [96, 192, 384, 768] |  | ||||||
|     if "base" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [128, 256, 512, 1024] |  | ||||||
|     if "large" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [192, 384, 768, 1536] |  | ||||||
|     if "xlarge" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [256, 512, 1024, 2048] |  | ||||||
|  |  | ||||||
|     if "1k" in checkpoint_url: |  | ||||||
|         num_labels = 1000 |  | ||||||
|         filename = "imagenet-1k-id2label.json" |  | ||||||
|         expected_shape = (1, 1000) |  | ||||||
|     else: |  | ||||||
|         num_labels = 21841 |  | ||||||
|         filename = "imagenet-22k-id2label.json" |  | ||||||
|         expected_shape = (1, 21841) |  | ||||||
|  |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     config.num_labels = num_labels |  | ||||||
|     id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|     if "1k" not in checkpoint_url: |  | ||||||
|         # this dataset contains 21843 labels but the model only has 21841 |  | ||||||
|         # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18 |  | ||||||
|         del id2label[9205] |  | ||||||
|         del id2label[15027] |  | ||||||
|     config.id2label = id2label |  | ||||||
|     config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|     config.hidden_sizes = hidden_sizes |  | ||||||
|     config.depths = depths |  | ||||||
|  |  | ||||||
|     return config, expected_shape |  | ||||||
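| # For example, a checkpoint URL containing "tiny" and "1k" yields depths [3, 3, 9, 3], hidden sizes |  | ||||||
| # [96, 192, 384, 768] and a 1000-class ImageNet-1k head. |  | ||||||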
|  |  | ||||||
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     if "downsample_layers.0.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings") |  | ||||||
|     if "downsample_layers.0.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on |  | ||||||
|     if "downsample_layers.1.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.1.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1") |  | ||||||
|     if "downsample_layers.2.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.2.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1") |  | ||||||
|     if "downsample_layers.3.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.3.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1") |  | ||||||
|     if "stages" in name and "downsampling_layer" not in name: |  | ||||||
|         # stages.0.0. for instance should be renamed to stages.0.layers.0. |  | ||||||
|         name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :] |  | ||||||
|     if "stages" in name: |  | ||||||
|         name = name.replace("stages", "encoder.stages") |  | ||||||
|     if "norm" in name: |  | ||||||
|         name = name.replace("norm", "layernorm") |  | ||||||
|     if "gamma" in name: |  | ||||||
|         name = name.replace("gamma", "layer_scale_parameter") |  | ||||||
|     if "head" in name: |  | ||||||
|         name = name.replace("head", "classifier") |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
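| # For example, "stages.0.0.dwconv.weight" in the original checkpoint becomes |  | ||||||
| # "encoder.stages.0.layers.0.dwconv.weight" here (the "convnext." prefix is prepended later in |  | ||||||
| # convert_convnext_checkpoint). |  | ||||||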
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     im = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|     return im |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our ConvNext structure. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     # define ConvNext configuration based on URL |  | ||||||
|     config, expected_shape = get_convnext_config(checkpoint_url) |  | ||||||
|     # load original state_dict from URL |  | ||||||
|     state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"] |  | ||||||
|     # rename keys |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         state_dict[rename_key(key)] = val |  | ||||||
|     # add prefix to all keys except the classifier head |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         if not key.startswith("classifier"): |  | ||||||
|             key = "convnext." + key |  | ||||||
|         state_dict[key] = val |  | ||||||
|  |  | ||||||
|     # load HuggingFace model |  | ||||||
|     model = ConvNextForImageClassification(config) |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     # Check outputs on an image, prepared by ConvNextImageProcessor |  | ||||||
|     size = 224 if "224" in checkpoint_url else 384 |  | ||||||
|     image_processor = ConvNextImageProcessor(size=size) |  | ||||||
|     pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values |  | ||||||
|  |  | ||||||
|     logits = model(pixel_values).logits |  | ||||||
|  |  | ||||||
|     # note: the logits below were obtained without center cropping |  | ||||||
|     if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([0.4525, 0.7539, 0.0308]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth": |  | ||||||
|         expected_logits = torch.tensor([0.3561, 0.6350, -0.0384]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([0.4174, -0.0989, 0.1489]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth": |  | ||||||
|         expected_logits = torch.tensor([0.2513, -0.1349, -0.1613]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth": |  | ||||||
|         expected_logits = torch.tensor([1.2980, 0.3631, -0.1198]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth": |  | ||||||
|         expected_logits = torch.tensor([1.2963, 0.1227, 0.1723]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth": |  | ||||||
|         expected_logits = torch.tensor([1.7956, 0.8390, 0.2820]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth": |  | ||||||
|         expected_logits = torch.tensor([0.2681, 0.2365, 0.6246]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth": |  | ||||||
|         expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444]) |  | ||||||
|     else: |  | ||||||
|         raise ValueError(f"Unknown URL: {checkpoint_url}") |  | ||||||
|  |  | ||||||
|     assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3) |  | ||||||
|     assert logits.shape == expected_shape |  | ||||||
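|     # the expected values are the first three logits produced by the original checkpoints; a mismatch |  | ||||||
|     # beyond atol=1e-3 indicates that the weight conversion or the preprocessing is off |  | ||||||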
|  |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     print(f"Saving model to {pytorch_dump_folder_path}") |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     print(f"Saving image processor to {pytorch_dump_folder_path}") |  | ||||||
|     image_processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     print("Pushing model to the hub...") |  | ||||||
|     model_name = "convnext" |  | ||||||
|     if "tiny" in checkpoint_url: |  | ||||||
|         model_name += "-tiny" |  | ||||||
|     elif "small" in checkpoint_url: |  | ||||||
|         model_name += "-small" |  | ||||||
|     elif "base" in checkpoint_url: |  | ||||||
|         model_name += "-base" |  | ||||||
|     elif "xlarge" in checkpoint_url: |  | ||||||
|         model_name += "-xlarge" |  | ||||||
|     elif "large" in checkpoint_url: |  | ||||||
|         model_name += "-large" |  | ||||||
|     if "224" in checkpoint_url: |  | ||||||
|         model_name += "-224" |  | ||||||
|     elif "384" in checkpoint_url: |  | ||||||
|         model_name += "-384" |  | ||||||
|     if "22k" in checkpoint_url and "1k" not in checkpoint_url: |  | ||||||
|         model_name += "-22k" |  | ||||||
|     if "22k" in checkpoint_url and "1k" in checkpoint_url: |  | ||||||
|         model_name += "-22k-1k" |  | ||||||
|  |  | ||||||
|     model.push_to_hub( |  | ||||||
|         repo_path_or_name=Path(pytorch_dump_folder_path, model_name), |  | ||||||
|         organization="nielsr", |  | ||||||
|         commit_message="Add model", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_url", |  | ||||||
|         default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", |  | ||||||
|         type=str, |  | ||||||
|         help="URL of the original ConvNeXT checkpoint you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default=None, |  | ||||||
|         type=str, |  | ||||||
|         required=True, |  | ||||||
|         help="Path to the output PyTorch model directory.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) |  | ||||||
| @ -1,286 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2023 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert ConvNeXTV2 checkpoints from the original repository. |  | ||||||
|  |  | ||||||
| URL: https://github.com/facebookresearch/ConvNeXt""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| import os |  | ||||||
|  |  | ||||||
| import requests |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
| from PIL import Image |  | ||||||
|  |  | ||||||
| from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification |  | ||||||
| from transformers.image_utils import PILImageResampling |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_convnextv2_config(checkpoint_url): |  | ||||||
|     config = ConvNextV2Config() |  | ||||||
|  |  | ||||||
|     if "atto" in checkpoint_url: |  | ||||||
|         depths = [2, 2, 6, 2] |  | ||||||
|         hidden_sizes = [40, 80, 160, 320] |  | ||||||
|     if "femto" in checkpoint_url: |  | ||||||
|         depths = [2, 2, 6, 2] |  | ||||||
|         hidden_sizes = [48, 96, 192, 384] |  | ||||||
|     if "pico" in checkpoint_url: |  | ||||||
|         depths = [2, 2, 6, 2] |  | ||||||
|         hidden_sizes = [64, 128, 256, 512] |  | ||||||
|     if "nano" in checkpoint_url: |  | ||||||
|         depths = [2, 2, 8, 2] |  | ||||||
|         hidden_sizes = [80, 160, 320, 640] |  | ||||||
|     if "tiny" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 9, 3] |  | ||||||
|         hidden_sizes = [96, 192, 384, 768] |  | ||||||
|     if "base" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [128, 256, 512, 1024] |  | ||||||
|     if "large" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [192, 384, 768, 1536] |  | ||||||
|     if "huge" in checkpoint_url: |  | ||||||
|         depths = [3, 3, 27, 3] |  | ||||||
|         hidden_sizes = [352, 704, 1408, 2816] |  | ||||||
|  |  | ||||||
|     num_labels = 1000 |  | ||||||
|     filename = "imagenet-1k-id2label.json" |  | ||||||
|     expected_shape = (1, 1000) |  | ||||||
|  |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     config.num_labels = num_labels |  | ||||||
|     id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     config.id2label = id2label |  | ||||||
|     config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|     config.hidden_sizes = hidden_sizes |  | ||||||
|     config.depths = depths |  | ||||||
|  |  | ||||||
|     return config, expected_shape |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def rename_key(name): |  | ||||||
|     if "downsample_layers.0.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings") |  | ||||||
|     if "downsample_layers.0.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on |  | ||||||
|     if "downsample_layers.1.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.1.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1") |  | ||||||
|     if "downsample_layers.2.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.2.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1") |  | ||||||
|     if "downsample_layers.3.0" in name: |  | ||||||
|         name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0") |  | ||||||
|     if "downsample_layers.3.1" in name: |  | ||||||
|         name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1") |  | ||||||
|     if "stages" in name and "downsampling_layer" not in name: |  | ||||||
|         # stages.0.0. for instance should be renamed to stages.0.layers.0. |  | ||||||
|         name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :] |  | ||||||
|     if "gamma" in name: |  | ||||||
|         name = name.replace("gamma", "weight") |  | ||||||
|     if "beta" in name: |  | ||||||
|         name = name.replace("beta", "bias") |  | ||||||
|     if "stages" in name: |  | ||||||
|         name = name.replace("stages", "encoder.stages") |  | ||||||
|     if "norm" in name: |  | ||||||
|         name = name.replace("norm", "layernorm") |  | ||||||
|     if "head" in name: |  | ||||||
|         name = name.replace("head", "classifier") |  | ||||||
|  |  | ||||||
|     return name |  | ||||||
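| # note that, unlike the v1 script above where "gamma" maps to "layer_scale_parameter", the v2 |  | ||||||
| # checkpoints map "gamma"/"beta" straight to the "weight"/"bias" of the corresponding module |  | ||||||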
|  |  | ||||||
|  |  | ||||||
| # We will verify our results on an image of cute cats |  | ||||||
| def prepare_img(): |  | ||||||
|     url = "http://images.cocodataset.org/val2017/000000039769.jpg" |  | ||||||
|     im = Image.open(requests.get(url, stream=True).raw) |  | ||||||
|     return im |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_preprocessor(checkpoint_url): |  | ||||||
|     if "224" in checkpoint_url: |  | ||||||
|         size = 224 |  | ||||||
|         crop_pct = 224 / 256 |  | ||||||
|     elif "384" in checkpoint_url: |  | ||||||
|         size = 384 |  | ||||||
|         crop_pct = None |  | ||||||
|     else: |  | ||||||
|         size = 512 |  | ||||||
|         crop_pct = None |  | ||||||
|  |  | ||||||
|     return ConvNextImageProcessor( |  | ||||||
|         size=size, |  | ||||||
|         crop_pct=crop_pct, |  | ||||||
|         image_mean=[0.485, 0.456, 0.406], |  | ||||||
|         image_std=[0.229, 0.224, 0.225], |  | ||||||
|         resample=PILImageResampling.BICUBIC, |  | ||||||
|     ) |  | ||||||
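| # with size=224, crop_pct=224/256 reproduces the original evaluation transform (resize the short side |  | ||||||
| # to 256, then center-crop to 224); the 384/512 checkpoints are resized directly, so crop_pct is left |  | ||||||
| # as None |  | ||||||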
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to our ConvNeXTV2 structure. |  | ||||||
|     """ |  | ||||||
|     print("Downloading original model from checkpoint...") |  | ||||||
|     # define ConvNeXTV2 configuration based on URL |  | ||||||
|     config, expected_shape = get_convnextv2_config(checkpoint_url) |  | ||||||
|     # load original state_dict from URL |  | ||||||
|     state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"] |  | ||||||
|  |  | ||||||
|     print("Converting model parameters...") |  | ||||||
|     # rename keys |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         state_dict[rename_key(key)] = val |  | ||||||
|     # add prefix to all keys except the classifier head |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         val = state_dict.pop(key) |  | ||||||
|         if not key.startswith("classifier"): |  | ||||||
|             key = "convnextv2." + key |  | ||||||
|         state_dict[key] = val |  | ||||||
|  |  | ||||||
|     # load HuggingFace model |  | ||||||
|     model = ConvNextV2ForImageClassification(config) |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|     model.eval() |  | ||||||
|  |  | ||||||
|     # Check outputs on an image, prepared by ConvNextImageProcessor |  | ||||||
|     preprocessor = convert_preprocessor(checkpoint_url) |  | ||||||
|     inputs = preprocessor(images=prepare_img(), return_tensors="pt") |  | ||||||
|     logits = model(**inputs).logits |  | ||||||
|  |  | ||||||
|     # note: the logits below were obtained without center cropping |  | ||||||
|     if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309]) |  | ||||||
|     elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt": |  | ||||||
|         expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826]) |  | ||||||
|     else: |  | ||||||
|         raise ValueError(f"Unknown URL: {checkpoint_url}") |  | ||||||
|  |  | ||||||
|     assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3) |  | ||||||
|     assert logits.shape == expected_shape |  | ||||||
|     print("Model outputs match the original results!") |  | ||||||
|  |  | ||||||
|     if save_model: |  | ||||||
|         print("Saving model to local...") |  | ||||||
|         # Create folder to save model |  | ||||||
|         if not os.path.isdir(pytorch_dump_folder_path): |  | ||||||
|             os.mkdir(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|         model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|         preprocessor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     model_name = "convnextv2" |  | ||||||
|     if "atto" in checkpoint_url: |  | ||||||
|         model_name += "-atto" |  | ||||||
|     if "femto" in checkpoint_url: |  | ||||||
|         model_name += "-femto" |  | ||||||
|     if "pico" in checkpoint_url: |  | ||||||
|         model_name += "-pico" |  | ||||||
|     if "nano" in checkpoint_url: |  | ||||||
|         model_name += "-nano" |  | ||||||
|     elif "tiny" in checkpoint_url: |  | ||||||
|         model_name += "-tiny" |  | ||||||
|     elif "base" in checkpoint_url: |  | ||||||
|         model_name += "-base" |  | ||||||
|     elif "large" in checkpoint_url: |  | ||||||
|         model_name += "-large" |  | ||||||
|     elif "huge" in checkpoint_url: |  | ||||||
|         model_name += "-huge" |  | ||||||
|     if "22k" in checkpoint_url and "1k" not in checkpoint_url: |  | ||||||
|         model_name += "-22k" |  | ||||||
|     elif "22k" in checkpoint_url and "1k" in checkpoint_url: |  | ||||||
|         model_name += "-22k-1k" |  | ||||||
|     elif "1k" in checkpoint_url: |  | ||||||
|         model_name += "-1k" |  | ||||||
|     if "224" in checkpoint_url: |  | ||||||
|         model_name += "-224" |  | ||||||
|     elif "384" in checkpoint_url: |  | ||||||
|         model_name += "-384" |  | ||||||
|     elif "512" in checkpoint_url: |  | ||||||
|         model_name += "-512" |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         print(f"Pushing {model_name} to the hub...") |  | ||||||
|         model.push_to_hub(model_name) |  | ||||||
|         preprocessor.push_to_hub(model_name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     # Required parameters |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--checkpoint_url", |  | ||||||
|         default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt", |  | ||||||
|         type=str, |  | ||||||
|         help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", |  | ||||||
|         default="model", |  | ||||||
|         type=str, |  | ||||||
|         help="Path to the output PyTorch model directory.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--save_model", action="store_true", help="Save model to local") |  | ||||||
|     parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub") |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_convnextv2_checkpoint( |  | ||||||
|         args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub |  | ||||||
|     ) |  | ||||||
| @ -1,362 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2022 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert CvT checkpoints from the original repository. |  | ||||||
|  |  | ||||||
| URL: https://github.com/microsoft/CvT""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import json |  | ||||||
| from collections import OrderedDict |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
|  |  | ||||||
| from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def embeddings(idx): |  | ||||||
|     """ |  | ||||||
|     The function helps in renaming embedding layer weights. |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         idx: stage number in original model |  | ||||||
|     """ |  | ||||||
|     embed = [] |  | ||||||
|     embed.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight", |  | ||||||
|             f"stage{idx}.patch_embed.proj.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     embed.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias", |  | ||||||
|             f"stage{idx}.patch_embed.proj.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     embed.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight", |  | ||||||
|             f"stage{idx}.patch_embed.norm.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     embed.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias", |  | ||||||
|             f"stage{idx}.patch_embed.norm.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     return embed |  | ||||||
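| # note that each tuple is (HuggingFace key, original CvT key), i.e. the reverse of the |  | ||||||
| # (original, converted) ordering used in the DETR/ConvNeXt scripts above |  | ||||||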
|  |  | ||||||
|  |  | ||||||
| def attention(idx, cnt): |  | ||||||
|     """ |  | ||||||
|     The function helps in renaming attention block layer weights. |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         idx: stage number in original model |  | ||||||
|         cnt: count of blocks in each stage |  | ||||||
|     """ |  | ||||||
|     attention_weights = [] |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_q.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_q.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_k.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_k.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_v.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj_v.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj.weight", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         ( |  | ||||||
|             f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias", |  | ||||||
|             f"stage{idx}.blocks.{cnt}.attn.proj.bias", |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight") |  | ||||||
|     ) |  | ||||||
|     attention_weights.append( |  | ||||||
|         (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias") |  | ||||||
|     ) |  | ||||||
|     return attention_weights |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def cls_token(idx): |  | ||||||
|     """ |  | ||||||
|     Helper function for renaming the cls_token weights (only the final stage has a cls token). |  | 
|     """ |  | ||||||
|     token = [] |  | ||||||
|     token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token")) |  | ||||||
|     return token |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def final(): |  | ||||||
|     """ |  | ||||||
|     Helper function for renaming the final layernorm and classification head weights. |  | 
|     """ |  | ||||||
|     head = [] |  | ||||||
|     head.append(("layernorm.weight", "norm.weight")) |  | ||||||
|     head.append(("layernorm.bias", "norm.bias")) |  | ||||||
|     head.append(("classifier.weight", "head.weight")) |  | ||||||
|     head.append(("classifier.bias", "head.bias")) |  | ||||||
|     return head |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder): |  | ||||||
|     """ |  | ||||||
|     Function to convert a Microsoft CvT checkpoint to a Hugging Face checkpoint. |  | 
|     """ |  | ||||||
|     img_labels_file = "imagenet-1k-id2label.json" |  | ||||||
|     num_labels = 1000 |  | ||||||
|  |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     num_labels = num_labels |  | ||||||
|     id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text()) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     id2label = id2label |  | ||||||
|     label2id = {v: k for k, v in id2label.items()} |  | ||||||
|  |  | ||||||
|     config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id) |  | ||||||
|  |  | ||||||
|     # For depth size 13 (13 = 1+2+10) |  | ||||||
|     if cvt_model.rsplit("/", 1)[-1][4:6] == "13": |  | ||||||
|         config.depth = [1, 2, 10] |  | ||||||
|  |  | ||||||
|     # For depth size 21 (21 = 1+4+16) |  | ||||||
|     elif cvt_model.rsplit("/", 1)[-1][4:6] == "21": |  | ||||||
|         config.depth = [1, 4, 16] |  | ||||||
|  |  | ||||||
|     # For wide CvT (similar to wide ResNet), depth size 24 (w24 = 2 + 2 + 20) |  | 
|     else: |  | ||||||
|         config.depth = [2, 2, 20] |  | ||||||
|         config.num_heads = [3, 12, 16] |  | ||||||
|         config.embed_dim = [192, 768, 1024] |  | ||||||
|  |  | ||||||
|     model = CvtForImageClassification(config) |  | ||||||
|     image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k") |  | ||||||
|     image_processor.size["shortest_edge"] = image_size |  | ||||||
|     original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu")) |  | ||||||
|  |  | ||||||
|     huggingface_weights = OrderedDict() |  | ||||||
|     list_of_state_dict = [] |  | ||||||
|  |  | ||||||
|     for idx in range(len(config.depth)): |  | ||||||
|         if config.cls_token[idx]: |  | ||||||
|             list_of_state_dict = list_of_state_dict + cls_token(idx) |  | ||||||
|         list_of_state_dict = list_of_state_dict + embeddings(idx) |  | ||||||
|         for cnt in range(config.depth[idx]): |  | ||||||
|             list_of_state_dict = list_of_state_dict + attention(idx, cnt) |  | ||||||
|  |  | ||||||
|     list_of_state_dict = list_of_state_dict + final() |  | ||||||
|     for gg in list_of_state_dict: |  | ||||||
|         print(gg) |  | ||||||
|     for i in range(len(list_of_state_dict)): |  | ||||||
|         huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]] |  | ||||||
|  |  | ||||||
|     model.load_state_dict(huggingface_weights) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder) |  | ||||||
|     image_processor.save_pretrained(pytorch_dump_folder) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al |  | ||||||
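| # Example invocation (script name and paths are illustrative, not taken from the original file): |  | 
| #   python convert_cvt_original_pytorch_checkpoint_to_pytorch.py --cvt_model cvt-w24 --image_size 384 \ |  | 
| #       --cvt_file_name cvtmodels/CvT-w24-384x384-IN-22k.pth --pytorch_dump_folder_path ./cvt-w24-384 |  | 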
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--cvt_model", |  | ||||||
|         default="cvt-w24", |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the cvt model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--image_size", |  | ||||||
|         default=384, |  | ||||||
|         type=int, |  | ||||||
|         help="Input Image Size", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--cvt_file_name", |  | ||||||
|         default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth", |  | ||||||
|         type=str, |  | ||||||
|         help="Input Image Size", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path) |  | ||||||
| @ -1,233 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2024 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert DAB-DETR checkpoints.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import gc |  | ||||||
| import json |  | ||||||
| import re |  | ||||||
| from pathlib import Path |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download |  | ||||||
|  |  | ||||||
| from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection |  | ||||||
| from transformers.utils import logging |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| ORIGINAL_TO_CONVERTED_KEY_MAPPING = { |  | ||||||
|     # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads |  | ||||||
|     # for dab-DETR, also convert reference point head and query scale MLP |  | ||||||
|     r"input_proj\.(bias|weight)": r"input_projection.\1", |  | ||||||
|     r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight", |  | ||||||
|     r"class_embed\.(bias|weight)": r"class_embed.\1", |  | ||||||
|     # negative lookbehind because of the overlap |  | ||||||
|     r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2", |  | ||||||
|     r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2", |  | ||||||
|     r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2", |  | ||||||
|     r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1", |  | ||||||
|     r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2", |  | ||||||
|     r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2", |  | ||||||
|     r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1", |  | ||||||
|     # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function |  | ||||||
|     # output projection |  | ||||||
|     r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2", |  | ||||||
|     # FFN layers |  | ||||||
|     r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3", |  | ||||||
|     # normalization layers |  | ||||||
|     # nm1 |  | ||||||
|     r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2", |  | ||||||
|     # nm2 |  | ||||||
|     r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2", |  | ||||||
|     # activation function weight |  | ||||||
|     r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight", |  | ||||||
|     ######################################################################################################################################### |  | ||||||
|     # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activiation function weight |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2", |  | ||||||
|     # FFNs |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3", |  | ||||||
|     # nm1 |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2", |  | ||||||
|     # nm2 |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2", |  | ||||||
|     # nm3 |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2", |  | ||||||
|     # activation function weight |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight", |  | ||||||
|     # q, k, v projections and biases in self-attention in decoder |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2", |  | ||||||
|     # q, k, v projections in cross-attention in decoder |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2", |  | ||||||
|     r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2", |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys |  | ||||||
| def convert_old_keys_to_new_keys(state_dict_keys: dict = None): |  | ||||||
|     """ |  | ||||||
|     This function should be applied only once, on the concatenated keys to efficiently rename using |  | ||||||
|     the key mappings. |  | ||||||
|     """ |  | ||||||
|     output_dict = {} |  | ||||||
|     if state_dict_keys is not None: |  | ||||||
|         old_text = "\n".join(state_dict_keys) |  | ||||||
|         new_text = old_text |  | ||||||
|         for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): |  | ||||||
|             if replacement is None: |  | ||||||
|                 new_text = re.sub(pattern, "", new_text)  # an empty line |  | ||||||
|                 continue |  | ||||||
|             new_text = re.sub(pattern, replacement, new_text) |  | ||||||
|         output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) |  | ||||||
|     return output_dict |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub): |  | ||||||
|     logger.info("Converting image processor...") |  | ||||||
|     format = "coco_detection" |  | ||||||
|     image_processor = ConditionalDetrImageProcessor(format=format) |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     image_processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub): |  | ||||||
|     # Load a modified config for the dilated (dc5) variant: once the default config is instantiated, the backbone kwargs are already set, so dilation has to be passed at construction time. |  | 
|     if "dc5" in model_name: |  | ||||||
|         config = DabDetrConfig(dilation=True) |  | ||||||
|     else: |  | ||||||
|         # load default config |  | ||||||
|         config = DabDetrConfig() |  | ||||||
|     # set other attributes |  | ||||||
|     if "dab-detr-resnet-50-dc5" == model_name: |  | ||||||
|         config.temperature_height = 10 |  | ||||||
|         config.temperature_width = 10 |  | ||||||
|     if "fixxy" in model_name: |  | ||||||
|         config.random_refpoints_xy = True |  | ||||||
|     if "pat3" in model_name: |  | ||||||
|         config.num_patterns = 3 |  | ||||||
|         # only applies when the number of patterns (num_patterns in the config) is greater than 0, e.g. r50-pat3 or r50dc5-pat3 |  | 
|         ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"}) |  | ||||||
|  |  | ||||||
|     config.num_labels = 91 |  | ||||||
|     repo_id = "huggingface/label-files" |  | ||||||
|     filename = "coco-detection-id2label.json" |  | ||||||
|     id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) |  | ||||||
|     id2label = {int(k): v for k, v in id2label.items()} |  | ||||||
|     config.id2label = id2label |  | ||||||
|     config.label2id = {v: k for k, v in id2label.items()} |  | ||||||
|     # load original model from local path |  | ||||||
|     loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"))["model"] |  | ||||||
|     # Rename the original model state dict keys to their HF-compatible counterparts |  | 
|     all_keys = list(loaded.keys()) |  | ||||||
|     new_keys = convert_old_keys_to_new_keys(all_keys) |  | ||||||
|     state_dict = {} |  | ||||||
|     for key in all_keys: |  | ||||||
|         if "backbone.0.body" in key: |  | ||||||
|             new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone") |  | ||||||
|             state_dict[new_key] = loaded[key] |  | ||||||
|         # Q, K, V encoder values mapping |  | ||||||
|         elif re.search("self_attn.in_proj_(weight|bias)", key): |  | ||||||
|             # Dynamically find the layer number |  | ||||||
|             pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)" |  | ||||||
|             match = re.search(pattern, key) |  | ||||||
|             if match: |  | ||||||
|                 layer_num = match.group(1) |  | ||||||
|             else: |  | ||||||
|                 raise ValueError(f"Pattern not found in key: {key}") |  | ||||||
|  |  | ||||||
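|             # torch's nn.MultiheadAttention stores the q/k/v projections packed in a single in_proj tensor |  | 
|             # of shape (3 * hidden_size, ...) with hidden_size = 256 here, so it is split into three equal chunks. |  | 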
|             in_proj_value = loaded.pop(key) |  | ||||||
|             if "weight" in key: |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :] |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :] |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :] |  | ||||||
|             elif "bias" in key: |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256] |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512] |  | ||||||
|                 state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:] |  | ||||||
|         else: |  | ||||||
|             new_key = new_keys[key] |  | ||||||
|             state_dict[new_key] = loaded[key] |  | ||||||
|  |  | ||||||
|     del loaded |  | ||||||
|     gc.collect() |  | ||||||
|     # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them |  | ||||||
|     prefix = "model." |  | ||||||
|     for key in state_dict.copy().keys(): |  | ||||||
|         if not key.startswith("class_embed") and not key.startswith("bbox_predictor"): |  | ||||||
|             val = state_dict.pop(key) |  | ||||||
|             state_dict[prefix + key] = val |  | ||||||
|     # finally, create HuggingFace model and load state dict |  | ||||||
|     model = DabDetrForObjectDetection(config) |  | ||||||
|     model.load_state_dict(state_dict) |  | ||||||
|     model.eval() |  | ||||||
|     logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...") |  | ||||||
|     Path(pytorch_dump_folder_path).mkdir(exist_ok=True) |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if push_to_hub: |  | ||||||
|         model.push_to_hub(repo_id=model_name, commit_message="Add new model") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub): |  | ||||||
|     logger.info("Converting image processor...") |  | ||||||
|     write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub) |  | ||||||
|  |  | ||||||
|     logger.info(f"Converting model {model_name}...") |  | ||||||
|     write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|  |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model_name", |  | ||||||
|         default="dab-detr-resnet-50", |  | ||||||
|         type=str, |  | ||||||
|         help="Name of the DAB_DETR model you'd like to convert.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pretrained_model_weights_path", |  | ||||||
|         default="modelzoo/R50/checkpoint.pth", |  | ||||||
|         type=str, |  | ||||||
|         help="The path of the original model weights like: modelzoo/checkpoint.pth", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", |  | ||||||
|         default=True, |  | ||||||
|         type=bool, |  | ||||||
|         help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.", |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_dab_detr_checkpoint( |  | ||||||
|         args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub |  | ||||||
|     ) |  | ||||||
| @ -1,261 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| import argparse |  | ||||||
| import fnmatch |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
|  |  | ||||||
| from transformers import ( |  | ||||||
|     DacConfig, |  | ||||||
|     DacFeatureExtractor, |  | ||||||
|     DacModel, |  | ||||||
|     logging, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # checkpoints downloaded using: |  | ||||||
| # pip install descript-audio-codec |  | ||||||
| # python3 -m dac download # downloads the default 44kHz variant |  | ||||||
| # python3 -m dac download --model_type 44khz # downloads the 44kHz variant |  | ||||||
| # python3 -m dac download --model_type 24khz # downloads the 24kHz variant |  | ||||||
| # python3 -m dac download --model_type 16khz # downloads the 16kHz variant |  | ||||||
| # More information: https://github.com/descriptinc/descript-audio-codec/tree/main |  | 
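| # Example conversion run (script name, paths, and sample rate are illustrative, not taken from the original file): |  | 
| #   python convert_dac_checkpoint.py --model dac_44khz --checkpoint_path ./weights.pth \ |  | 
| #       --pytorch_dump_folder_path ./dac_44khz --sample_rate 44100 |  | 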
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger("transformers.models.dac") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def match_pattern(string, pattern): |  | ||||||
|     # Split the pattern into parts |  | ||||||
|     pattern_parts = pattern.split(".") |  | ||||||
|     string_parts = string.split(".") |  | ||||||
|  |  | ||||||
|     pattern_block_count = string_block_count = 0 |  | ||||||
|  |  | ||||||
|     for part in pattern_parts: |  | ||||||
|         if part.startswith("block"): |  | ||||||
|             pattern_block_count += 1 |  | ||||||
|  |  | ||||||
|     for part in string_parts: |  | ||||||
|         if part.startswith("block"): |  | ||||||
|             string_block_count += 1 |  | ||||||
|  |  | ||||||
|     return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TOP_LEVEL_KEYS = [] |  | ||||||
| IGNORE_KEYS = [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MAPPING_ENCODER = { |  | ||||||
|     "encoder.block.0": ["encoder.conv1"], |  | ||||||
|     "encoder.block.5": ["encoder.snake1"], |  | ||||||
|     "encoder.block.6": ["encoder.conv2"], |  | ||||||
|     "encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"], |  | ||||||
|     "encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"], |  | ||||||
|     "encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"], |  | ||||||
|     "encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"], |  | ||||||
|     "encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"], |  | ||||||
|     "encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"], |  | ||||||
| } |  | ||||||
|  |  | ||||||
| MAPPING_QUANTIZER = { |  | ||||||
|     "quantizer.quantizers.*": ["quantizer.quantizers.*"], |  | ||||||
| } |  | ||||||
|  |  | ||||||
| MAPPING_DECODER = { |  | ||||||
|     "decoder.model.0": ["decoder.conv1"], |  | ||||||
|     "decoder.model.5": ["decoder.snake1"], |  | ||||||
|     "decoder.model.6": ["decoder.conv2"], |  | ||||||
|     "decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"], |  | ||||||
|     "decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"], |  | ||||||
|     "decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"], |  | ||||||
|     "decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"], |  | ||||||
|     "decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"], |  | ||||||
|     "decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"], |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MAPPING = { |  | ||||||
|     **MAPPING_ENCODER, |  | ||||||
|     **MAPPING_QUANTIZER, |  | ||||||
|     **MAPPING_DECODER, |  | ||||||
| } |  | ||||||
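| # Mapping values are lists of HF name parts; recursively_load_weights below recombines them, |  | 
| # shifting the numeric block indices to match the HF module numbering. |  | 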
|  |  | ||||||
|  |  | ||||||
| def set_recursively(hf_pointer, key, value, full_name, weight_type): |  | ||||||
|     for attribute in key.split("."): |  | ||||||
|         hf_pointer = getattr(hf_pointer, attribute) |  | ||||||
|  |  | ||||||
|     if weight_type is not None: |  | ||||||
|         hf_shape = getattr(hf_pointer, weight_type).shape |  | ||||||
|     else: |  | ||||||
|         hf_shape = hf_pointer.shape |  | ||||||
|  |  | ||||||
|     if hf_shape != value.shape: |  | ||||||
|         raise ValueError( |  | ||||||
|             f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" |  | ||||||
|             f" {value.shape} for {full_name}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if weight_type == "weight": |  | ||||||
|         hf_pointer.weight.data = value |  | ||||||
|     elif weight_type == "weight_g": |  | ||||||
|         hf_pointer.weight_g.data = value |  | ||||||
|     elif weight_type == "weight_v": |  | ||||||
|         hf_pointer.weight_v.data = value |  | ||||||
|     elif weight_type == "bias": |  | ||||||
|         hf_pointer.bias.data = value |  | ||||||
|     elif weight_type == "alpha": |  | ||||||
|         hf_pointer.alpha.data = value |  | ||||||
|     logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def should_ignore(name, ignore_keys): |  | ||||||
|     for key in ignore_keys: |  | ||||||
|         if key.endswith(".*"): |  | ||||||
|             if name.startswith(key[:-1]): |  | ||||||
|                 return True |  | ||||||
|         elif ".*." in key: |  | ||||||
|             prefix, suffix = key.split(".*.") |  | ||||||
|             if prefix in name and suffix in name: |  | ||||||
|                 return True |  | ||||||
|         elif key in name: |  | ||||||
|             return True |  | ||||||
|     return False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def recursively_load_weights(orig_dict, hf_model, model_name): |  | ||||||
|     unused_weights = [] |  | ||||||
|  |  | ||||||
|     if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]: |  | ||||||
|         raise ValueError(f"Unsupported model: {model_name}") |  | ||||||
|  |  | ||||||
|     for name, value in orig_dict.items(): |  | ||||||
|         is_used = False |  | ||||||
|         for key, mapped_key in MAPPING.items(): |  | ||||||
|             regex = re.compile(key) |  | ||||||
|             if regex.search(name): |  | ||||||
|                 if len(mapped_key) == 1: |  | ||||||
|                     if mapped_key[0][0] == "q": |  | ||||||
|                         mapped_key = ".".join(name.split(".")[:-1]) |  | ||||||
|                     else: |  | ||||||
|                         mapped_key = mapped_key[0] |  | ||||||
|                 elif len(mapped_key) == 3: |  | ||||||
|                     integers = re.findall(r"\b\d+\b", name) |  | ||||||
|                     if mapped_key[0][0] == "d": |  | ||||||
|                         mapped_key = "{}.{}.{}{}.{}".format( |  | ||||||
|                             mapped_key[0], |  | ||||||
|                             str(int(integers[0]) - 1), |  | ||||||
|                             mapped_key[1], |  | ||||||
|                             str(int(integers[1]) - 1), |  | ||||||
|                             mapped_key[2], |  | ||||||
|                         ) |  | ||||||
|                     else: |  | ||||||
|                         mapped_key = "{}.{}.{}{}.{}".format( |  | ||||||
|                             mapped_key[0], |  | ||||||
|                             str(int(integers[0]) - 1), |  | ||||||
|                             mapped_key[1], |  | ||||||
|                             str(int(integers[1]) + 1), |  | ||||||
|                             mapped_key[2], |  | ||||||
|                         ) |  | ||||||
|                 elif len(mapped_key) == 2: |  | ||||||
|                     integers = re.findall(r"\b\d+\b", name) |  | ||||||
|                     mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1]) |  | ||||||
|  |  | ||||||
|                 is_used = True |  | ||||||
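|                 # Infer which attribute of the target module to set from the original parameter name |  | 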
|                 if "weight_g" in name: |  | ||||||
|                     weight_type = "weight_g" |  | ||||||
|                 elif "weight_v" in name: |  | ||||||
|                     weight_type = "weight_v" |  | ||||||
|                 elif "bias" in name: |  | ||||||
|                     weight_type = "bias" |  | ||||||
|                 elif "alpha" in name: |  | ||||||
|                     weight_type = "alpha" |  | ||||||
|                 elif "weight" in name: |  | ||||||
|                     weight_type = "weight" |  | ||||||
|                 set_recursively(hf_model, mapped_key, value, name, weight_type) |  | ||||||
|  |  | ||||||
|         if not is_used: |  | ||||||
|             unused_weights.append(name) |  | ||||||
|  |  | ||||||
|     print(list(set(unused_weights))) |  | ||||||
|  |  | ||||||
|     logger.warning(f"Unused weights: {unused_weights}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_checkpoint( |  | ||||||
|     model_name, |  | ||||||
|     checkpoint_path, |  | ||||||
|     pytorch_dump_folder_path, |  | ||||||
|     sample_rate=16000, |  | ||||||
|     repo_id=None, |  | ||||||
| ): |  | ||||||
|     model_dict = torch.load(checkpoint_path, "cpu") |  | ||||||
|  |  | ||||||
|     config = DacConfig() |  | ||||||
|  |  | ||||||
|     metadata = model_dict["metadata"]["kwargs"] |  | ||||||
|     config.encoder_hidden_size = metadata["encoder_dim"] |  | ||||||
|     config.downsampling_ratios = metadata["encoder_rates"] |  | ||||||
|     config.codebook_size = metadata["codebook_size"] |  | ||||||
|     config.n_codebooks = metadata["n_codebooks"] |  | ||||||
|     config.codebook_dim = metadata["codebook_dim"] |  | ||||||
|     config.decoder_hidden_size = metadata["decoder_dim"] |  | ||||||
|     config.upsampling_ratios = metadata["decoder_rates"] |  | ||||||
|     config.quantizer_dropout = float(metadata["quantizer_dropout"]) |  | ||||||
|     config.sampling_rate = sample_rate |  | ||||||
|  |  | ||||||
|     model = DacModel(config) |  | ||||||
|     feature_extractor = DacFeatureExtractor() |  | ||||||
|     feature_extractor.sampling_rate = sample_rate |  | ||||||
|  |  | ||||||
|     original_checkpoint = model_dict["state_dict"] |  | ||||||
|  |  | ||||||
|     model.apply_weight_norm() |  | ||||||
|     recursively_load_weights(original_checkpoint, model, model_name) |  | ||||||
|     model.remove_weight_norm() |  | ||||||
|  |  | ||||||
|     model.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if repo_id: |  | ||||||
|         print("Pushing to the hub...") |  | ||||||
|         feature_extractor.push_to_hub(repo_id) |  | ||||||
|         model.push_to_hub(repo_id) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--model", |  | ||||||
|         default="dac_44khz", |  | ||||||
|         type=str, |  | ||||||
|         help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." |  | ||||||
|     ) |  | ||||||
|     parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor") |  | ||||||
|     args = parser.parse_args() |  | ||||||
|  |  | ||||||
|     convert_checkpoint( |  | ||||||
|         args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub |  | ||||||
|     ) |  | ||||||
| @ -1,285 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Inc. team. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
| """Convert Wav2Vec2 checkpoint.""" |  | ||||||
|  |  | ||||||
| import argparse |  | ||||||
| import os |  | ||||||
| from functools import reduce |  | ||||||
|  |  | ||||||
| import fairseq |  | ||||||
| import torch |  | ||||||
| from datasets import load_dataset |  | ||||||
|  |  | ||||||
| from transformers import Wav2Vec2Processor, logging |  | ||||||
| from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig |  | ||||||
|  |  | ||||||
| # Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py |  | ||||||
| from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy  # noqa: F401 |  | ||||||
| from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| logging.set_verbosity_info() |  | ||||||
| logger = logging.get_logger(__name__) |  | ||||||
|  |  | ||||||
| MAPPING = { |  | ||||||
|     "post_extract_proj": "feature_projection.projection", |  | ||||||
|     "models.0.layer_norm": "feature_projection.layer_norm", |  | ||||||
|     "self_attn.k_proj": "encoder.layers.*.attention.k_proj", |  | ||||||
|     "self_attn.v_proj": "encoder.layers.*.attention.v_proj", |  | ||||||
|     "self_attn.q_proj": "encoder.layers.*.attention.q_proj", |  | ||||||
|     "self_attn.out_proj": "encoder.layers.*.attention.out_proj", |  | ||||||
|     "self_attn_layer_norm": "encoder.layers.*.layer_norm", |  | ||||||
|     "fc1": "encoder.layers.*.feed_forward.intermediate_dense", |  | ||||||
|     "fc2": "encoder.layers.*.feed_forward.output_dense", |  | ||||||
|     "final_layer_norm": "encoder.layers.*.final_layer_norm", |  | ||||||
|     "encoder.layer_norm": "encoder.layer_norm", |  | ||||||
|     "w2v_model.layer_norm": "feature_projection.layer_norm", |  | ||||||
|     "w2v_encoder.proj": "lm_head", |  | ||||||
|     "mask_emb": "masked_spec_embed", |  | ||||||
| } |  | ||||||
| TOP_LEVEL_KEYS = [ |  | ||||||
|     "lm_head", |  | ||||||
| ] |  | ||||||
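| # Keys in TOP_LEVEL_KEYS live on the task head and are therefore not prefixed with "data2vec_audio." |  | 
| # when converting the full (non-headless) model (see recursively_load_weights below). |  | 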
|  |  | ||||||
|  |  | ||||||
| def set_recursively(hf_pointer, key, value, full_name, weight_type): |  | ||||||
|     for attribute in key.split("."): |  | ||||||
|         hf_pointer = getattr(hf_pointer, attribute) |  | ||||||
|  |  | ||||||
|     if weight_type is not None: |  | ||||||
|         hf_shape = getattr(hf_pointer, weight_type).shape |  | ||||||
|     else: |  | ||||||
|         hf_shape = hf_pointer.shape |  | ||||||
|  |  | ||||||
|     if hf_shape != value.shape: |  | ||||||
|         raise ValueError( |  | ||||||
|             f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" |  | ||||||
|             f" {value.shape} for {full_name}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if weight_type == "weight": |  | ||||||
|         hf_pointer.weight.data = value |  | ||||||
|     elif weight_type == "weight_g": |  | ||||||
|         hf_pointer.weight_g.data = value |  | ||||||
|     elif weight_type == "weight_v": |  | ||||||
|         hf_pointer.weight_v.data = value |  | ||||||
|     elif weight_type == "bias": |  | ||||||
|         hf_pointer.bias.data = value |  | ||||||
|     else: |  | ||||||
|         hf_pointer.data = value |  | ||||||
|  |  | ||||||
|     logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def recursively_load_weights(fairseq_model, hf_model, is_headless): |  | ||||||
|     unused_weights = [] |  | ||||||
|     fairseq_dict = fairseq_model.state_dict() |  | ||||||
|  |  | ||||||
|     if not is_headless: |  | ||||||
|         feature_extractor = hf_model.data2vec_audio.feature_extractor |  | ||||||
|         pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed |  | ||||||
|  |  | ||||||
|     else: |  | ||||||
|         feature_extractor = hf_model.feature_extractor |  | ||||||
|         pos_conv_embedding = hf_model.encoder.pos_conv_embed |  | ||||||
|  |  | ||||||
|     for name, value in fairseq_dict.items(): |  | ||||||
|         is_used = False |  | ||||||
|         if "conv_layers" in name: |  | ||||||
|             load_conv_layer( |  | ||||||
|                 name, |  | ||||||
|                 value, |  | ||||||
|                 feature_extractor, |  | ||||||
|                 unused_weights, |  | ||||||
|             ) |  | ||||||
|             is_used = True |  | ||||||
|         elif "pos_conv" in name: |  | ||||||
|             load_pos_conv_layer( |  | ||||||
|                 name, |  | ||||||
|                 value, |  | ||||||
|                 pos_conv_embedding, |  | ||||||
|                 unused_weights, |  | ||||||
|             ) |  | ||||||
|             is_used = True |  | ||||||
|         else: |  | ||||||
|             for key, mapped_key in MAPPING.items(): |  | ||||||
|                 if not is_headless: |  | ||||||
|                     mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key |  | ||||||
|                 if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: |  | ||||||
|                     is_used = True |  | ||||||
|                     if "*" in mapped_key: |  | ||||||
|                         layer_index = name.split(key)[0].split(".")[-2] |  | ||||||
|                         mapped_key = mapped_key.replace("*", layer_index) |  | ||||||
|                     if "weight_g" in name: |  | ||||||
|                         weight_type = "weight_g" |  | ||||||
|                     elif "weight_v" in name: |  | ||||||
|                         weight_type = "weight_v" |  | ||||||
|                     elif "bias" in name: |  | ||||||
|                         weight_type = "bias" |  | ||||||
|                     elif "weight" in name: |  | ||||||
|                         # TODO: don't match quantizer.weight_proj |  | ||||||
|                         weight_type = "weight" |  | ||||||
|                     else: |  | ||||||
|                         weight_type = None |  | ||||||
|                     set_recursively(hf_model, mapped_key, value, name, weight_type) |  | ||||||
|                 continue |  | ||||||
|         if not is_used: |  | ||||||
|             unused_weights.append(name) |  | ||||||
|  |  | ||||||
|     logger.warning(f"Unused weights: {unused_weights}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def access_by_string(module, path): |  | ||||||
|     names = path.split(".") |  | ||||||
|     return reduce(getattr, names, module) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def set_weights(full_name, module, fsq_value, hf_weight_path): |  | ||||||
|     hf_weight = access_by_string(module, hf_weight_path) |  | ||||||
|     hf_value = hf_weight.data |  | ||||||
|  |  | ||||||
|     if fsq_value.shape != hf_value.shape: |  | ||||||
|         raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.") |  | ||||||
|     hf_weight.data = fsq_value |  | ||||||
|     logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_conv_layer(full_name, value, feature_extractor, unused_weights): |  | ||||||
|     name = full_name.split("conv_layers.")[-1] |  | ||||||
|     items = name.split(".") |  | ||||||
|     layer_id = int(items[0]) |  | ||||||
|     type_id = int(items[1]) |  | ||||||
|  |  | ||||||
|     weight_type = name.split(".")[-1] |  | ||||||
|     if type_id == 0: |  | ||||||
|         layer_type = "conv" |  | ||||||
|     elif type_id == 2: |  | ||||||
|         layer_type = "layer_norm" |  | ||||||
|     else: |  | ||||||
|         unused_weights.append(full_name) |  | ||||||
|         return |  | ||||||
|  |  | ||||||
|     set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights): |  | ||||||
|     name = full_name.split("pos_conv.")[-1] |  | ||||||
|     items = name.split(".") |  | ||||||
|     layer_id = int(items[0]) |  | ||||||
|     type_id = int(items[1]) |  | ||||||
|  |  | ||||||
|     weight_type = name.split(".")[-1] |  | ||||||
|     if type_id != 0: |  | ||||||
|         unused_weights.append(full_name) |  | ||||||
|         return |  | ||||||
|     else: |  | ||||||
|         layer_type = "conv" |  | ||||||
|  |  | ||||||
|     set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @torch.no_grad() |  | ||||||
| def convert_wav2vec2_checkpoint( |  | ||||||
|     checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True |  | ||||||
| ): |  | ||||||
|     """ |  | ||||||
|     Copy/paste/tweak model's weights to transformers design. |  | ||||||
|     """ |  | ||||||
|     if config_path is not None: |  | ||||||
|         config = Data2VecAudioConfig.from_pretrained(config_path) |  | ||||||
|     else: |  | ||||||
|         config = Data2VecAudioConfig() |  | ||||||
|  |  | ||||||
|     if not is_finetuned: |  | ||||||
|         # Modify final_proj layer name |  | ||||||
|         hf_wav2vec = Data2VecAudioModel(config) |  | ||||||
|         data2vec_checkpoint_dir = os.path.dirname(checkpoint_path) |  | ||||||
|  |  | ||||||
|         state_dict = torch.load(checkpoint_path) |  | ||||||
|         state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight") |  | ||||||
|         state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias") |  | ||||||
|         converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt") |  | ||||||
|         torch.save(state_dict, converted_ckpt) |  | ||||||
|     else: |  | ||||||
|         hf_wav2vec = Data2VecAudioForCTC(config) |  | ||||||
|         converted_ckpt = checkpoint_path |  | ||||||
|  |  | ||||||
|     def load_data2vec(path): |  | ||||||
|         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path]) |  | ||||||
|         return model[0].eval() |  | ||||||
|  |  | ||||||
|     model = load_data2vec(converted_ckpt) |  | ||||||
|  |  | ||||||
|     recursively_load_weights(model, hf_wav2vec, not is_finetuned) |  | ||||||
|  |  | ||||||
|     processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60") |  | ||||||
|  |  | ||||||
|     ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True) |  | ||||||
|     input_audio = [x["array"] for x in ds[:4]["audio"]] |  | ||||||
|  |  | ||||||
|     inputs = processor(input_audio, return_tensors="pt", padding=True) |  | ||||||
|  |  | ||||||
|     input_values = inputs.input_values |  | ||||||
|     attention_mask = inputs.attention_mask |  | ||||||
|     #    input_values = inputs.input_values[:, :-1] |  | ||||||
|     #    attention_mask = inputs.attention_mask[:, :-1] |  | ||||||
|  |  | ||||||
|     hf_wav2vec.eval() |  | ||||||
|     model.eval() |  | ||||||
|     if is_finetuned: |  | ||||||
|         their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[ |  | ||||||
|             "encoder_out" |  | ||||||
|         ].transpose(0, 1) |  | ||||||
|         our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"] |  | ||||||
|  |  | ||||||
|         pred_ids = torch.argmax(our_output, dim=-1) |  | ||||||
|         output_string = processor.batch_decode(pred_ids) |  | ||||||
|  |  | ||||||
|         print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}") |  | ||||||
|     else: |  | ||||||
|         their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[ |  | ||||||
|             "layer_results" |  | ||||||
|         ][-1][0].transpose(0, 1) |  | ||||||
|         our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"] |  | ||||||
|  |  | ||||||
|     print(our_output.shape, their_output.shape) |  | ||||||
|     max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() |  | ||||||
|     print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7 |  | ||||||
|     success = torch.allclose(our_output, their_output, atol=1e-3) |  | ||||||
|     print("Do both models output the same tensors?", "🔥" if success else "💩") |  | ||||||
|     if not success: |  | ||||||
|         raise Exception("Something went wRoNg") |  | ||||||
|  |  | ||||||
|     hf_wav2vec.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|     if is_finetuned: |  | ||||||
|         processor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|     else: |  | ||||||
|         processor.feature_extractor.save_pretrained(pytorch_dump_folder_path) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     parser = argparse.ArgumentParser() |  | ||||||
|     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") |  | ||||||
|     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") |  | ||||||
|     parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") |  | ||||||
|     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" |  | ||||||
|     ) |  | ||||||
|     args = parser.parse_args() |  | ||||||
|     convert_wav2vec2_checkpoint( |  | ||||||
|         args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned |  | ||||||
|     ) |  | ||||||