Mirror of https://github.com/huggingface/transformers.git
Synced 2025-11-04 20:14:36 +08:00

Compare commits

1 commit

debug_circ ... v4.51.0

| Author | SHA1 | Date |
|---|---|---|
|  | 0720e206c6 |  |
@@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset
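For context (illustrative, not part of this diff): `check_min_version` compares the pinned version against the installed `transformers` package and aborts the example script at import time if the install is too old. A minimal sketch of that behaviour, assuming only the public helper shown above:

    from transformers.utils import check_min_version

    try:
        # With an older transformers install (e.g. 4.50.x), this is the guard that
        # makes the example scripts "error if the minimal version is not installed".
        check_min_version("4.51.0")
    except Exception as err:  # the helper raises when the installed version is older than the pin
        print(f"Refusing to run the example: {err}")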
@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils import check_min_version, send_example_telemetry
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 Array = Any
 Dataset = datasets.arrow_dataset.Dataset

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
@@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -60,7 +60,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = logging.getLogger(__name__)

@@ -54,7 +54,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
 # You should update this to your particular problem to have better documentation of `model_type`
@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logging.basicConfig(level=logging.INFO)
 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -46,7 +46,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)

@@ -50,7 +50,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)

@@ -48,7 +48,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

@@ -52,7 +52,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -57,7 +57,7 @@ from transformers.utils.versions import require_version


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = get_logger(__name__)
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version(
     "datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
@@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
 logger = logging.getLogger(__name__)

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

@@ -50,7 +50,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 logger = logging.getLogger(__name__)

@@ -53,7 +53,7 @@ from transformers.utils.versions import require_version

 # region Checking dependencies
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

@@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry


 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 task_to_keys = {
     "cola": ("sentence", None),

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

 # region Dependencies and constants
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-check_min_version("4.51.0.dev0")
+check_min_version("4.51.0")

 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
setup.py

@@ -453,7 +453,7 @@ install_requires = [

 setup(
     name="transformers",
-    version="4.51.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.51.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
     author_email="transformers@huggingface.co",
     description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
@@ -18,7 +18,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).

-__version__ = "4.51.0.dev0"
+__version__ = "4.51.0"

 from typing import TYPE_CHECKING
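For context (illustrative only, not the project's actual release tooling): a bump like this one rewrites the same pinned string across the example scripts, `setup.py`, and `src/transformers/__init__.py`. Mechanically it amounts to something like:

    from pathlib import Path

    # Hypothetical helper mirroring the substitution made in this diff.
    OLD, NEW = '"4.51.0.dev0"', '"4.51.0"'
    for path in [Path("setup.py"), Path("src/transformers/__init__.py"), *Path("examples").rglob("*.py")]:
        text = path.read_text()
        if OLD in text:
            path.write_text(text.replace(OLD, NEW))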
@@ -1,62 +0,0 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALBERT checkpoint."""

import argparse

import torch

from ...utils import logging
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = AlbertConfig.from_json_file(albert_config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = AlbertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--albert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ALBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.albert_config_file, args.pytorch_dump_path)
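For reference (illustrative only; the paths are hypothetical), the removed ALBERT helper above was called like this, which is equivalent to running the deleted script with --tf_checkpoint_path / --albert_config_file / --pytorch_dump_path:

    # Hypothetical local paths, shown only to illustrate the call signature defined above.
    convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path="./albert_base/model.ckpt-best",
        albert_config_file="./albert_base/albert_config.json",
        pytorch_dump_path="./albert_base/pytorch_model.bin",
    )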
@@ -1,389 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ALIGN checkpoints from the original repository."""

import argparse
import os

import align
import numpy as np
import requests
import tensorflow as tf
import torch
from PIL import Image
from tokenizer import Tokenizer

from transformers import (
    AlignConfig,
    AlignModel,
    AlignProcessor,
    BertConfig,
    BertTokenizer,
    EfficientNetConfig,
    EfficientNetImageProcessor,
)
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def preprocess(image):
    image = tf.image.resize(image, (346, 346))
    image = tf.image.crop_to_bounding_box(image, (346 - 289) // 2, (346 - 289) // 2, 289, 289)
    return image


def get_align_config():
    vision_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
    vision_config.image_size = 289
    vision_config.hidden_dim = 640
    vision_config.id2label = {"0": "LABEL_0", "1": "LABEL_1"}
    vision_config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
    vision_config.depthwise_padding = []

    text_config = BertConfig()
    config = AlignConfig.from_text_vision_configs(
        text_config=text_config, vision_config=vision_config, projection_dim=640
    )
    return config


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


def get_processor():
    image_processor = EfficientNetImageProcessor(
        do_center_crop=True,
        rescale_factor=1 / 127.5,
        rescale_offset=True,
        do_normalize=False,
        include_top=False,
        resample=Image.BILINEAR,
    )
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokenizer.model_max_length = 64
    processor = AlignProcessor(image_processor=image_processor, tokenizer=tokenizer)
    return processor


# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
    # EfficientNet image encoder
    block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
    block_names = list(set(block_names))
    block_names = sorted(block_names)
    num_blocks = len(block_names)
    block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}

    rename_keys = []
    rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
    rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
    rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
    rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
    rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))

    for b in block_names:
        hf_b = block_name_mapping[b]
        rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
        rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
        rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
        rename_keys.append(
            (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
        )
        rename_keys.append(
            (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
        )
        rename_keys.append(
            (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
        )
        rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
        rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
        rename_keys.append(
            (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
        )
        rename_keys.append(
            (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
        )

        rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
        rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
        rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
        rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
        rename_keys.append(
            (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
        )
        rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
        rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
        rename_keys.append(
            (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
        )
        rename_keys.append(
            (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
        )

    key_mapping = {}
    for item in rename_keys:
        if item[0] in original_param_names:
            key_mapping[item[0]] = "vision_model." + item[1]

    # BERT text encoder
    rename_keys = []
    old = "tf_bert_model/bert"
    new = "text_model"
    for i in range(12):
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/query/kernel:0",
                f"{new}.encoder.layer.{i}.attention.self.query.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/query/bias:0",
                f"{new}.encoder.layer.{i}.attention.self.query.bias",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/key/kernel:0",
                f"{new}.encoder.layer.{i}.attention.self.key.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/key/bias:0",
                f"{new}.encoder.layer.{i}.attention.self.key.bias",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/value/kernel:0",
                f"{new}.encoder.layer.{i}.attention.self.value.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/self/value/bias:0",
                f"{new}.encoder.layer.{i}.attention.self.value.bias",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/output/dense/kernel:0",
                f"{new}.encoder.layer.{i}.attention.output.dense.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/output/dense/bias:0",
                f"{new}.encoder.layer.{i}.attention.output.dense.bias",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/gamma:0",
                f"{new}.encoder.layer.{i}.attention.output.LayerNorm.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/attention/output/LayerNorm/beta:0",
                f"{new}.encoder.layer.{i}.attention.output.LayerNorm.bias",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/intermediate/dense/kernel:0",
                f"{new}.encoder.layer.{i}.intermediate.dense.weight",
            )
        )
        rename_keys.append(
            (
                f"{old}/encoder/layer_._{i}/intermediate/dense/bias:0",
                f"{new}.encoder.layer.{i}.intermediate.dense.bias",
            )
        )
        rename_keys.append(
            (f"{old}/encoder/layer_._{i}/output/dense/kernel:0", f"{new}.encoder.layer.{i}.output.dense.weight")
        )
        rename_keys.append(
            (f"{old}/encoder/layer_._{i}/output/dense/bias:0", f"{new}.encoder.layer.{i}.output.dense.bias")
        )
        rename_keys.append(
            (f"{old}/encoder/layer_._{i}/output/LayerNorm/gamma:0", f"{new}.encoder.layer.{i}.output.LayerNorm.weight")
        )
        rename_keys.append(
            (f"{old}/encoder/layer_._{i}/output/LayerNorm/beta:0", f"{new}.encoder.layer.{i}.output.LayerNorm.bias")
        )

    rename_keys.append((f"{old}/embeddings/word_embeddings/weight:0", f"{new}.embeddings.word_embeddings.weight"))
    rename_keys.append(
        (f"{old}/embeddings/position_embeddings/embeddings:0", f"{new}.embeddings.position_embeddings.weight")
    )
    rename_keys.append(
        (f"{old}/embeddings/token_type_embeddings/embeddings:0", f"{new}.embeddings.token_type_embeddings.weight")
    )
    rename_keys.append((f"{old}/embeddings/LayerNorm/gamma:0", f"{new}.embeddings.LayerNorm.weight"))
    rename_keys.append((f"{old}/embeddings/LayerNorm/beta:0", f"{new}.embeddings.LayerNorm.bias"))

    rename_keys.append((f"{old}/pooler/dense/kernel:0", f"{new}.pooler.dense.weight"))
    rename_keys.append((f"{old}/pooler/dense/bias:0", f"{new}.pooler.dense.bias"))
    rename_keys.append(("dense/kernel:0", "text_projection.weight"))
    rename_keys.append(("dense/bias:0", "text_projection.bias"))
    rename_keys.append(("dense/bias:0", "text_projection.bias"))
    rename_keys.append(("temperature:0", "temperature"))

    for item in rename_keys:
        if item[0] in original_param_names:
            key_mapping[item[0]] = item[1]
    return key_mapping


def replace_params(hf_params, tf_params, key_mapping):
    list(hf_params.keys())

    for key, value in tf_params.items():
        if key not in key_mapping:
            continue

        hf_key = key_mapping[key]
        if "_conv" in key and "kernel" in key:
            new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
        elif "embeddings" in key:
            new_hf_value = torch.from_numpy(value)
        elif "depthwise_kernel" in key:
            new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
        elif "kernel" in key:
            new_hf_value = torch.from_numpy(np.transpose(value))
        elif "temperature" in key:
            new_hf_value = value
        elif "bn/gamma" or "bn/beta" in key:
            new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
        else:
            new_hf_value = torch.from_numpy(value)

        # Replace HF parameters with original TF model parameters
        hf_params[hf_key].copy_(new_hf_value)


@torch.no_grad()
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
    """
    Copy/paste/tweak model's weights to our ALIGN structure.
    """
    # Load original model
    seq_length = 64
    tok = Tokenizer(seq_length)
    original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
    original_model.compile()
    original_model.load_weights(checkpoint_path)

    tf_params = original_model.trainable_variables
    tf_non_train_params = original_model.non_trainable_variables
    tf_params = {param.name: param.numpy() for param in tf_params}
    for param in tf_non_train_params:
        tf_params[param.name] = param.numpy()
    tf_param_names = list(tf_params.keys())

    # Load HuggingFace model
    config = get_align_config()
    hf_model = AlignModel(config).eval()
    hf_params = hf_model.state_dict()

    # Create src-to-dst parameter name mapping dictionary
    print("Converting parameters...")
    key_mapping = rename_keys(tf_param_names)
    replace_params(hf_params, tf_params, key_mapping)

    # Initialize processor
    processor = get_processor()
    inputs = processor(
        images=prepare_img(), text="A picture of a cat", padding="max_length", max_length=64, return_tensors="pt"
    )

    # HF model inference
    hf_model.eval()
    with torch.no_grad():
        outputs = hf_model(**inputs)

    hf_image_features = outputs.image_embeds.detach().numpy()
    hf_text_features = outputs.text_embeds.detach().numpy()

    # Original model inference
    original_model.trainable = False
    tf_image_processor = EfficientNetImageProcessor(
        do_center_crop=True,
        do_rescale=False,
        do_normalize=False,
        include_top=False,
        resample=Image.BILINEAR,
    )
    image = tf_image_processor(images=prepare_img(), return_tensors="tf", data_format="channels_last")["pixel_values"]
    text = tok(tf.constant(["A picture of a cat"]))

    image_features = original_model.image_encoder(image, training=False)
    text_features = original_model.text_encoder(text, training=False)

    image_features = tf.nn.l2_normalize(image_features, axis=-1)
    text_features = tf.nn.l2_normalize(text_features, axis=-1)

    # Check whether original and HF model outputs match -> np.allclose
    if not np.allclose(image_features, hf_image_features, atol=1e-3):
        raise ValueError("The predicted image features are not the same.")
    if not np.allclose(text_features, hf_text_features, atol=1e-3):
        raise ValueError("The predicted text features are not the same.")
    print("Model outputs match!")

    if save_model:
        # Create folder to save model
        if not os.path.isdir(pytorch_dump_folder_path):
            os.mkdir(pytorch_dump_folder_path)
        # Save converted model and image processor
        hf_model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Push model and image processor to hub
        print("Pushing converted ALIGN to the hub...")
        processor.push_to_hub("align-base")
        hf_model.push_to_hub("align-base")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path",
        default="./weights/model-weights",
        type=str,
        help="Path to the pretrained TF ALIGN checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="hf_model",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")

    args = parser.parse_args()
    convert_align_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
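For reference (illustrative call mirroring the argparse defaults shown above; the checkpoint path is whatever local TF ALIGN weights are being converted):

    # Equivalent to running the deleted script with --checkpoint_path ./weights/model-weights
    # --pytorch_dump_folder_path hf_model --save_model (and without --push_to_hub).
    convert_align_checkpoint(
        checkpoint_path="./weights/model-weights",
        pytorch_dump_folder_path="hf_model",
        save_model=True,
        push_to_hub=False,
    )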
@ -1,162 +0,0 @@
 | 
			
		||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import argparse
 | 
			
		||||
import glob
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from safetensors import safe_open
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AddedToken,
 | 
			
		||||
    AriaForConditionalGeneration,
 | 
			
		||||
    AriaProcessor,
 | 
			
		||||
    AutoConfig,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EPILOG_TXT = """Example:
 | 
			
		||||
    python transformers/src/transformers/models/aria/convert_aria_weights_to_hf.py --text_model_id rhymes-ai/Aria --vision_model_id rhymes-ai/Aria --output_hub_path m-ric/Aria_hf_2 --old_state_dict_id rhymes-ai/Aria
 | 
			
		||||
 | 
			
		||||
Example for creating the old state dict file with Python:
 | 
			
		||||
 | 
			
		||||
    import torch
 | 
			
		||||
    from aria.model.language_model.aria_llama import AriaTextForCausalLM
 | 
			
		||||
 | 
			
		||||
    # load model
 | 
			
		||||
    kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
 | 
			
		||||
    model = AriaTextForCausalLM.from_pretrained("rhymes-ai/Aria", low_cpu_mem_usage=True, **kwargs)
 | 
			
		||||
 | 
			
		||||
    # load vision tower
 | 
			
		||||
    model.get_vision_tower().load_model()
 | 
			
		||||
 | 
			
		||||
    # Save state dict
 | 
			
		||||
    torch.save(model.state_dict(), "tmp/hf_models/aria/model_state_dict.bin")
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
KEYS_TO_MODIFY_MAPPING = {
 | 
			
		||||
    "vision_tower.vision_model": "vision_tower",
 | 
			
		||||
    "ln_ffn": "layer_norm",
 | 
			
		||||
    "ffn": "feed_forward",
 | 
			
		||||
    "ln_kv": "layer_norm_kv",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_original_state_dict(model_id):
 | 
			
		||||
    directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
 | 
			
		||||
 | 
			
		||||
    original_state_dict = {}
 | 
			
		||||
    for path in glob.glob(f"{directory_path}/*"):
 | 
			
		||||
        if path.endswith(".safetensors"):
 | 
			
		||||
            with safe_open(path, framework="pt", device="cpu") as f:
 | 
			
		||||
                for key in f.keys():
 | 
			
		||||
                    original_state_dict[key] = f.get_tensor(key)
 | 
			
		||||
 | 
			
		||||
    return original_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_state_dict_to_hf(state_dict):
 | 
			
		||||
    new_state_dict = {}
 | 
			
		||||
    for key, value in state_dict.items():
 | 
			
		||||
        if key.endswith(".inv_freq"):
 | 
			
		||||
            continue
 | 
			
		||||
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
 | 
			
		||||
            if key_to_modify in key:
 | 
			
		||||
                key = key.replace(key_to_modify, new_key)
 | 
			
		||||
 | 
			
		||||
        new_state_dict[key] = value
 | 
			
		||||
    new_state_dict["vision_tower.post_layernorm.weight"] = torch.zeros((1152,))
 | 
			
		||||
    new_state_dict["vision_tower.post_layernorm.bias"] = torch.zeros((1152,))
 | 
			
		||||
 | 
			
		||||
    return new_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
 | 
			
		||||
    torch.set_default_dtype(torch.float16)
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        text_model_id,
 | 
			
		||||
        extra_special_tokens={
 | 
			
		||||
            "image_token": "<|img|>",
 | 
			
		||||
            "pad_token": "<pad>",
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
    tokenizer.add_tokens(AddedToken("<|img|>", special=True, normalized=False), special_tokens=True)
 | 
			
		||||
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
 | 
			
		||||
    tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<fim_prefix><|img|><fim_suffix>{% endif %}{% endfor %}{% endif %}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
 | 
			
		||||
 | 
			
		||||
    processor = AriaProcessor.from_pretrained(
 | 
			
		||||
        text_model_id,
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config = AutoConfig.from_pretrained(text_model_id)
 | 
			
		||||
    config.vision_config.hidden_size = 1152
 | 
			
		||||
    config.vision_config.attention_heads = 16
 | 
			
		||||
    config.pad_token_id = 2
 | 
			
		||||
    config.image_token_index = 9
 | 
			
		||||
    config.intermediate_size = config.moe_intermediate_size
 | 
			
		||||
    config.auto_map = {
 | 
			
		||||
        "AutoConfig": "modeling_aria.AriaConfig",
 | 
			
		||||
        "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    with torch.device("meta"):
 | 
			
		||||
        model = AriaForConditionalGeneration(config)
 | 
			
		||||
 | 
			
		||||
    state_dict = load_original_state_dict(old_state_dict_id)
 | 
			
		||||
 | 
			
		||||
    state_dict = convert_state_dict_to_hf(state_dict)
 | 
			
		||||
    model.load_state_dict(state_dict, strict=False, assign=True)
 | 
			
		||||
 | 
			
		||||
    # print("Saving models")
 | 
			
		||||
    # model.save_pretrained("local_aria", safe_serialization=False)
 | 
			
		||||
    # processor.save_pretrained("local_aria")
 | 
			
		||||
    print("Pushing to hub")
 | 
			
		||||
    model.push_to_hub(output_hub_path, create_pr=True)
 | 
			
		||||
    processor.push_to_hub(output_hub_path, create_pr=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        epilog=EPILOG_TXT,
 | 
			
		||||
        formatter_class=argparse.RawDescriptionHelpFormatter,
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--text_model_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Hub location of the text model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--vision_model_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Hub location of the vision model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--output_hub_path",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Location on the hub of the converted model",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--old_state_dict_id",
 | 
			
		||||
        default="rhymes-ai/Aria",
 | 
			
		||||
        help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_aria_llama_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
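# Hedged usage sketch (not part of the original script): the conversion entry point above can also be
# driven directly from Python instead of the CLI. The repo ids mirror the argparse defaults; treat them
# as placeholders for your own locations.
#
# convert_aria_llama_to_hf(
#     text_model_id="rhymes-ai/Aria",
#     vision_model_id="rhymes-ai/Aria",
#     output_hub_path="rhymes-ai/Aria",    # hub repo the converted model/processor are pushed to (as a PR)
#     old_state_dict_id="rhymes-ai/Aria",  # hub repo holding the original *.safetensors shards
# )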
@@ -1,279 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Audio Spectrogram Transformer checkpoints from the original repository. URL: https://github.com/YuanGongND/ast"""

import argparse
import json
from pathlib import Path

import torch
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download

from transformers import ASTConfig, ASTFeatureExtractor, ASTForAudioClassification
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_audio_spectrogram_transformer_config(model_name):
    config = ASTConfig()

    if "10-10" in model_name:
        pass
    elif "speech-commands" in model_name:
        config.max_length = 128
    elif "12-12" in model_name:
        config.time_stride = 12
        config.frequency_stride = 12
    elif "14-14" in model_name:
        config.time_stride = 14
        config.frequency_stride = 14
    elif "16-16" in model_name:
        config.time_stride = 16
        config.frequency_stride = 16
    else:
        raise ValueError("Model not supported")

    repo_id = "huggingface/label-files"
    if "speech-commands" in model_name:
        config.num_labels = 35
        filename = "speech-commands-v2-id2label.json"
    else:
        config.num_labels = 527
        filename = "audioset-id2label.json"

    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config


def rename_key(name):
    if "module.v" in name:
        name = name.replace("module.v", "audio_spectrogram_transformer")
    if "cls_token" in name:
        name = name.replace("cls_token", "embeddings.cls_token")
    if "dist_token" in name:
        name = name.replace("dist_token", "embeddings.distillation_token")
    if "pos_embed" in name:
        name = name.replace("pos_embed", "embeddings.position_embeddings")
    if "patch_embed.proj" in name:
        name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
    # transformer blocks
    if "blocks" in name:
        name = name.replace("blocks", "encoder.layer")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "attn" in name:
        name = name.replace("attn", "attention.self")
    if "norm1" in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name:
        name = name.replace("norm2", "layernorm_after")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")
    # final layernorm
    if "audio_spectrogram_transformer.norm" in name:
        name = name.replace("audio_spectrogram_transformer.norm", "audio_spectrogram_transformer.layernorm")
    # classifier head
    if "module.mlp_head.0" in name:
        name = name.replace("module.mlp_head.0", "classifier.layernorm")
    if "module.mlp_head.1" in name:
        name = name.replace("module.mlp_head.1", "classifier.dense")

    return name


def convert_state_dict(orig_state_dict, config):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if "qkv" in key:
            key_split = key.split(".")
            layer_num = int(key_split[3])
            dim = config.hidden_size
            if "weight" in key:
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.weight"
                ] = val[:dim, :]
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.weight"
                ] = val[dim : dim * 2, :]
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.weight"
                ] = val[-dim:, :]
            else:
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.query.bias"
                ] = val[:dim]
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.key.bias"
                ] = val[dim : dim * 2]
                orig_state_dict[
                    f"audio_spectrogram_transformer.encoder.layer.{layer_num}.attention.attention.value.bias"
                ] = val[-dim:]
        else:
            orig_state_dict[rename_key(key)] = val

    return orig_state_dict


def remove_keys(state_dict):
    ignore_keys = [
        "module.v.head.weight",
        "module.v.head.bias",
        "module.v.head_dist.weight",
        "module.v.head_dist.bias",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


@torch.no_grad()
def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our Audio Spectrogram Transformer structure.
    """
    config = get_audio_spectrogram_transformer_config(model_name)

    model_name_to_url = {
        "ast-finetuned-audioset-10-10-0.4593": (
            "https://www.dropbox.com/s/ca0b1v2nlxzyeb4/audioset_10_10_0.4593.pth?dl=1"
        ),
        "ast-finetuned-audioset-10-10-0.450": (
            "https://www.dropbox.com/s/1tv0hovue1bxupk/audioset_10_10_0.4495.pth?dl=1"
        ),
        "ast-finetuned-audioset-10-10-0.448": (
            "https://www.dropbox.com/s/6u5sikl4b9wo4u5/audioset_10_10_0.4483.pth?dl=1"
        ),
        "ast-finetuned-audioset-10-10-0.448-v2": (
            "https://www.dropbox.com/s/kt6i0v9fvfm1mbq/audioset_10_10_0.4475.pth?dl=1"
        ),
        "ast-finetuned-audioset-12-12-0.447": (
            "https://www.dropbox.com/s/snfhx3tizr4nuc8/audioset_12_12_0.4467.pth?dl=1"
        ),
        "ast-finetuned-audioset-14-14-0.443": (
            "https://www.dropbox.com/s/z18s6pemtnxm4k7/audioset_14_14_0.4431.pth?dl=1"
        ),
        "ast-finetuned-audioset-16-16-0.442": (
            "https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1"
        ),
        "ast-finetuned-speech-commands-v2": (
            "https://www.dropbox.com/s/q0tbqpwv44pquwy/speechcommands_10_10_0.9812.pth?dl=1"
        ),
    }

    # load original state_dict
    checkpoint_url = model_name_to_url[model_name]
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # remove some keys
    remove_keys(state_dict)
    # rename some keys
    new_state_dict = convert_state_dict(state_dict, config)

    # load 🤗 model
    model = ASTForAudioClassification(config)
    model.eval()

    model.load_state_dict(new_state_dict)

    # verify outputs on dummy input
    # source: https://github.com/YuanGongND/ast/blob/79e873b8a54d0a3b330dd522584ff2b9926cd581/src/run.py#L62
    mean = -4.2677393 if "speech-commands" not in model_name else -6.845978
    std = 4.5689974 if "speech-commands" not in model_name else 5.5654526
    max_length = 1024 if "speech-commands" not in model_name else 128
    feature_extractor = ASTFeatureExtractor(mean=mean, std=std, max_length=max_length)

    if "speech-commands" in model_name:
        # TODO: Convert dataset to Parquet
        dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
        waveform = dataset[0]["audio"]["array"]
    else:
        filepath = hf_hub_download(
            repo_id="nielsr/audio-spectogram-transformer-checkpoint",
            filename="sample_audio.flac",
            repo_type="dataset",
        )

        waveform, _ = torchaudio.load(filepath)
        waveform = waveform.squeeze().numpy()

    inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

    # forward pass
    outputs = model(**inputs)
    logits = outputs.logits

    if model_name == "ast-finetuned-audioset-10-10-0.4593":
        expected_slice = torch.tensor([-0.8760, -7.0042, -8.6602])
    elif model_name == "ast-finetuned-audioset-10-10-0.450":
        expected_slice = torch.tensor([-1.1986, -7.0903, -8.2718])
    elif model_name == "ast-finetuned-audioset-10-10-0.448":
        expected_slice = torch.tensor([-2.6128, -8.0080, -9.4344])
    elif model_name == "ast-finetuned-audioset-10-10-0.448-v2":
        expected_slice = torch.tensor([-1.5080, -7.4534, -8.8917])
    elif model_name == "ast-finetuned-audioset-12-12-0.447":
        expected_slice = torch.tensor([-0.5050, -6.5833, -8.0843])
    elif model_name == "ast-finetuned-audioset-14-14-0.443":
        expected_slice = torch.tensor([-0.3826, -7.0336, -8.2413])
    elif model_name == "ast-finetuned-audioset-16-16-0.442":
        expected_slice = torch.tensor([-1.2113, -6.9101, -8.3470])
    elif model_name == "ast-finetuned-speech-commands-v2":
        expected_slice = torch.tensor([6.1589, -8.0566, -8.7984])
    else:
        raise ValueError("Unknown model name")
    if not torch.allclose(logits[0, :3], expected_slice, atol=1e-4):
        raise ValueError("Logits don't match")
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving feature extractor to {pytorch_dump_folder_path}")
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and feature extractor to the hub...")
        model.push_to_hub(f"MIT/{model_name}")
        feature_extractor.push_to_hub(f"MIT/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="ast-finetuned-audioset-10-10-0.4593",
        type=str,
        help="Name of the Audio Spectrogram Transformer model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_audio_spectrogram_transformer_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
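# Hedged usage sketch (not in the original file): the converter above downloads the checkpoint for
# `model_name`, verifies a slice of the logits against hard-coded values, then saves locally and/or
# pushes to the hub. The output directory below is a hypothetical example.
#
# convert_audio_spectrogram_transformer_checkpoint(
#     model_name="ast-finetuned-audioset-10-10-0.4593",
#     pytorch_dump_folder_path="/tmp/ast",  # hypothetical local output directory
#     push_to_hub=False,
# )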
@@ -1,273 +0,0 @@
# coding=utf-8
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""

import argparse
import json
import os
import re
from os import path
from typing import Dict, Optional, Union

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

from transformers import AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

from .configuration_bamba import BambaConfig


def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
    state_dict = {}

    for orig_k, param in original_sd.items():
        k = orig_k.replace("backbone", "model")

        # for embeddings
        k = k.replace("embedding", "embed_tokens")

        # for mixer
        k = k.replace("mixer", "mamba")

        # for final layernorm
        k = k.replace("norm_f", "final_layernorm")

        # for block layernorm
        k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
        k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)

        # for mlp
        k = k.replace("mlp.fc2", "feed_forward.down_proj")

        if "mlp.fc1" in k:
            param, param2 = torch.chunk(param, 2, dim=0)
            k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
            state_dict[k2] = param2
            k = k.replace("mlp.fc1", "feed_forward.up_proj")

        if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
            "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
        ):
            # then this must be a mamba
            pass
        else:
            # for attn
            # - because mixer was replaced to mamba above
            k = k.replace("mamba.out_proj", "self_attn.o_proj")
            if "mamba.in_proj" in k:
                m, n = param.shape
                d = (m - n) // 2
                param, param2, param3 = torch.split(param, [n, d, d], dim=0)
                k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
                state_dict[k2] = param2
                k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
                state_dict[k2] = param3
                k = k.replace("mamba.in_proj", "self_attn.q_proj")

        state_dict[k] = param

    return state_dict


# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_ssm_config_to_hf_config(
    config_ssm: Dict,
    **kwargs,
) -> BambaConfig:
    """Convert a config from mamba_ssm to a BambaConfig from here."""
    hf_config: BambaConfig = BambaConfig(**kwargs)

    hf_config.architectures = ["BambaForCausalLM"]

    # Set important values from config and recalculate other resulting entries
    hf_config.hidden_size = config_ssm["d_model"]
    hf_config.intermediate_size = config_ssm["d_intermediate"]
    hf_config.mamba_n_heads = (hf_config.hidden_size * hf_config.mamba_expand) // hf_config.mamba_d_head
    hf_config.num_hidden_layers = config_ssm["n_layer"]
    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]

    # currently this script assumes config_ssm belongs to v2
    if config_ssm["ssm_cfg"].get("layer") != "Mamba2":
        raise ValueError("Conversion script only supports Mamba2")

    # Set attention values
    attn_cfg = config_ssm.get("attn_cfg")
    if attn_cfg:
        assert attn_cfg["causal"], "Only support causal attention."
        assert not attn_cfg["qkv_proj_bias"], "Only support no qkv bias."
        assert not attn_cfg["out_proj_bias"], "Only support no out bias."
        hf_config.attn_rotary_emb = attn_cfg["rotary_emb_dim"]
        hf_config.num_attention_heads = attn_cfg["num_heads"]
        hf_config.num_key_value_heads = attn_cfg["num_heads_kv"]

    attention_layer_indices = config_ssm.get("attn_layer_idx")
    if attention_layer_indices:
        hf_config.attn_layer_indices = attention_layer_indices

    # Pad the vocab size to a multiple; 16 is most common, but 32 is also used by some models
    vocab_size = config_ssm["vocab_size"]
    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
    if (vocab_size % pad_vocab_size_multiple) != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    hf_config.vocab_size = vocab_size

    return hf_config


def save_single_safetensor(
    state_dict: Dict,
    save_directory: str,
    metadata: Dict,
):
    save_file(
        state_dict,
        os.path.join(save_directory, SAFE_WEIGHTS_NAME),
        metadata,
    )


def save_sharded_safetensors(
    state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    # Save the index
    with open(os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
        f.write(content)

    filename_to_tensors = state_dict_split.filename_to_tensors.items()
    for shard_file, tensors in filename_to_tensors:
        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)


# Adapted from transformers.models.mamba.convert_mamba_ssm_checkpoint_to_pytorch.py
def convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
    mamba_ssm_checkpoint_path: str,
    precision: str,
    output_dir: str,
    tokenizer_path: Optional[str] = None,
    save_model: Union[bool, str] = True,
) -> None:
    # load tokenizer if provided, this will be used to set the
    # token_ids in the config file
    token_ids = {}
    if tokenizer_path:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        for key in [
            "bos_token_id",
            "eos_token_id",
            "pad_token_id",
        ]:
            id = getattr(tokenizer, key, None)
            if id:
                token_ids[key] = id

    # there are some configs that cannot be set from the mamba_ssm config, so
    # if they differ from the defaults they have to be passed into
    # the function
    unsettables = {
        "mamba_d_head": 64,
        "mamba_d_state": 128,
        "mamba_n_groups": 1,
        "rms_norm_eps": 1e-5,
    }

    # Load and save config based on name
    config_path = path.join(mamba_ssm_checkpoint_path, "config.json")
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)

    # convert the config
    hf_config = convert_ssm_config_to_hf_config(
        config_ssm=config,
        **token_ids,
        **unsettables,
    )
    hf_config.save_pretrained(output_dir)

    # Load state dict of the original model and transfer to hf model
    state_dict = torch.load(
        path.join(mamba_ssm_checkpoint_path, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )
    # FIXME: allow other parameters to pass in
    state_dict = convert_state_dict_from_mamba_ssm(state_dict)

    # Save new model to pytorch_dump_path
    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)

    save_file_fn = None
    if isinstance(save_model, bool) and save_model:
        save_file_fn = save_single_safetensor
    elif isinstance(save_model, str) and save_model == "sharded":
        save_file_fn = save_sharded_safetensors

    if save_file_fn:
        save_file_fn({k: v.to(dtype) for k, v in state_dict.items()}, output_dir, metadata={"format": "pt"})


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba_ssm_checkpoint_directory",
        type=str,
        required=True,
        help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="fp16",
        const="fp16",
        required=True,
        choices=("fp32", "fp16", "bf16"),
        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    parser.add_argument(
        "-t",
        "--tokenizer_model_path",
        type=str,
        default=None,
        required=False,
        help="Path to the tokenizer file.",
    )
    args = parser.parse_args()

    convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
        args.mamba_ssm_checkpoint_directory,
        args.precision,
        args.output_dir,
    )
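# Hedged usage sketch (not in the original file): equivalent to invoking the CLI above with
# `-i ... -p fp16 -o ...`; both directories below are hypothetical placeholders.
#
# convert_mamba_ssm_checkpoint_file_to_huggingface_model_file(
#     mamba_ssm_checkpoint_path="/path/to/mamba_ssm_ckpt",  # must contain config.json and pytorch_model.bin
#     precision="fp16",
#     output_dir="/path/to/converted_model",
# )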
@@ -1,263 +0,0 @@
"""Convert Bark checkpoint."""

import argparse
import os
from pathlib import Path

import torch
from bark.generation import _load_model as _bark_load_model
from huggingface_hub import hf_hub_download

from transformers import EncodecConfig, EncodecModel, set_seed
from transformers.models.bark.configuration_bark import (
    BarkCoarseConfig,
    BarkConfig,
    BarkFineConfig,
    BarkSemanticConfig,
)
from transformers.models.bark.generation_configuration_bark import (
    BarkCoarseGenerationConfig,
    BarkFineGenerationConfig,
    BarkGenerationConfig,
    BarkSemanticGenerationConfig,
)
from transformers.models.bark.modeling_bark import BarkCoarseModel, BarkFineModel, BarkModel, BarkSemanticModel
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

set_seed(770)


new_layer_name_dict = {
    "c_attn": "att_proj",
    "c_proj": "out_proj",
    "c_fc": "in_proj",
    "transformer.": "",
    "h.": "layers.",
    "ln_1": "layernorm_1",
    "ln_2": "layernorm_2",
    "ln_f": "layernorm_final",
    "wpe": "position_embeds_layer",
    "wte": "input_embeds_layer",
}


REMOTE_MODEL_PATHS = {
    "text_small": {
        "repo_id": "suno/bark",
        "file_name": "text.pt",
    },
    "coarse_small": {
        "repo_id": "suno/bark",
        "file_name": "coarse.pt",
    },
    "fine_small": {
        "repo_id": "suno/bark",
        "file_name": "fine.pt",
    },
    "text": {
        "repo_id": "suno/bark",
        "file_name": "text_2.pt",
    },
    "coarse": {
        "repo_id": "suno/bark",
        "file_name": "coarse_2.pt",
    },
    "fine": {
        "repo_id": "suno/bark",
        "file_name": "fine_2.pt",
    },
}

CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")


def _get_ckpt_path(model_type, use_small=False):
    key = model_type
    if use_small:
        key += "_small"
    return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])


def _download(from_hf_path, file_name):
    os.makedirs(CACHE_DIR, exist_ok=True)
    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)


def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    if model_type == "text":
        ModelClass = BarkSemanticModel
        ConfigClass = BarkSemanticConfig
        GenerationConfigClass = BarkSemanticGenerationConfig
    elif model_type == "coarse":
        ModelClass = BarkCoarseModel
        ConfigClass = BarkCoarseConfig
        GenerationConfigClass = BarkCoarseGenerationConfig
    elif model_type == "fine":
        ModelClass = BarkFineModel
        ConfigClass = BarkFineConfig
        GenerationConfigClass = BarkFineGenerationConfig
    else:
        raise NotImplementedError()
    model_key = f"{model_type}_small" if use_small else model_type
    model_info = REMOTE_MODEL_PATHS[model_key]
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
        _download(model_info["repo_id"], model_info["file_name"])
    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]

    # convert Bark model arguments to HF Bark model arguments
    model_args["num_heads"] = model_args.pop("n_head")
    model_args["hidden_size"] = model_args.pop("n_embd")
    model_args["num_layers"] = model_args.pop("n_layer")

    model_config = ConfigClass(**checkpoint["model_args"])
    model = ModelClass(config=model_config)
    model_generation_config = GenerationConfigClass()

    model.generation_config = model_generation_config
    state_dict = checkpoint["model"]
    # fixup checkpoint
    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            # replace part of the key with corresponding layer name in HF implementation
            new_k = k[len(unwanted_prefix) :]
            for old_layer_name in new_layer_name_dict:
                new_k = new_k.replace(old_layer_name, new_layer_name_dict[old_layer_name])

            state_dict[new_k] = state_dict.pop(k)

    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = {k for k in extra_keys if not k.endswith(".attn.bias")}
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = {k for k in missing_keys if not k.endswith(".attn.bias")}
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    model.load_state_dict(state_dict, strict=False)
    n_params = model.num_parameters(exclude_embeddings=True)
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params, {round(val_loss, 3)} loss")
    model.eval()
    model.to(device)
    del checkpoint, state_dict

    return model


def load_model(pytorch_dump_folder_path, use_small=False, model_type="text"):
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()

    device = "cpu"  # do conversion on cpu

    ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
    model = _load_model(ckpt_path, device, model_type=model_type, use_small=use_small)

    # load bark initial model
    bark_model = _bark_load_model(ckpt_path, "cpu", model_type=model_type, use_small=use_small)

    if model_type == "text":
        bark_model = bark_model["model"]

    if model.num_parameters(exclude_embeddings=True) != bark_model.get_num_params():
        raise ValueError("initial and new models don't have the same number of parameters")

    # check if same output as the bark model
    batch_size = 5
    sequence_length = 10

    if model_type in ["text", "coarse"]:
        vec = torch.randint(256, (batch_size, sequence_length), dtype=torch.int)
        output_old_model = bark_model(vec)[0]

        output_new_model_total = model(vec)

        # take last logits
        output_new_model = output_new_model_total.logits[:, [-1], :]

    else:
        prediction_codebook_channel = 3
        n_codes_total = 8
        vec = torch.randint(256, (batch_size, sequence_length, n_codes_total), dtype=torch.int)

        output_new_model_total = model(prediction_codebook_channel, vec)
        output_old_model = bark_model(prediction_codebook_channel, vec)

        output_new_model = output_new_model_total.logits

    # output difference should come from the difference of self-attention implementation design
    if output_new_model.shape != output_old_model.shape:
        raise ValueError("initial and new outputs don't have the same shape")
    if (output_new_model - output_old_model).abs().max().item() > 1e-3:
        raise ValueError("initial and new outputs are not equal")

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


def load_whole_bark_model(
    semantic_path,
    coarse_path,
    fine_path,
    append_text,
    hub_path,
    folder_path,
):
    pytorch_dump_folder_path = os.path.join(folder_path, append_text)

    semanticConfig = BarkSemanticConfig.from_pretrained(os.path.join(semantic_path, "config.json"))
    coarseAcousticConfig = BarkCoarseConfig.from_pretrained(os.path.join(coarse_path, "config.json"))
    fineAcousticConfig = BarkFineConfig.from_pretrained(os.path.join(fine_path, "config.json"))
    codecConfig = EncodecConfig.from_pretrained("facebook/encodec_24khz")

    semantic = BarkSemanticModel.from_pretrained(semantic_path)
    coarseAcoustic = BarkCoarseModel.from_pretrained(coarse_path)
    fineAcoustic = BarkFineModel.from_pretrained(fine_path)
    codec = EncodecModel.from_pretrained("facebook/encodec_24khz")

    bark_config = BarkConfig.from_sub_model_configs(
        semanticConfig, coarseAcousticConfig, fineAcousticConfig, codecConfig
    )

    bark_generation_config = BarkGenerationConfig.from_sub_model_configs(
        semantic.generation_config, coarseAcoustic.generation_config, fineAcoustic.generation_config
    )

    bark = BarkModel(bark_config)

    bark.semantic = semantic
    bark.coarse_acoustics = coarseAcoustic
    bark.fine_acoustics = fineAcoustic
    bark.codec_model = codec

    bark.generation_config = bark_generation_config

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    bark.save_pretrained(pytorch_dump_folder_path, repo_id=hub_path, push_to_hub=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters

    parser.add_argument("model_type", type=str, help="text, coarse or fine.")
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--is_small", action="store_true", help="convert the small version instead of the large.")

    args = parser.parse_args()

    load_model(args.pytorch_dump_folder_path, model_type=args.model_type, use_small=args.is_small)
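# Hedged usage sketch (not in the original file): converts a single Bark sub-model (here the "text"
# semantic model), checking its outputs against the original suno implementation before saving.
# The dump folder is a hypothetical example.
#
# load_model("/tmp/bark_text", use_small=False, model_type="text")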
@@ -1,156 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BART checkpoint."""

import argparse
import os
from pathlib import Path

import fairseq
import torch
from packaging import version
from torch import nn

from transformers import (
    BartConfig,
    BartForConditionalGeneration,
    BartForSequenceClassification,
    BartModel,
    BartTokenizer,
)
from transformers.utils import logging


FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = " Hello world! cécé herlolip"

mnli_rename_keys = [
    ("model.classification_heads.mnli.dense.weight", "classification_head.dense.weight"),
    ("model.classification_heads.mnli.dense.bias", "classification_head.dense.bias"),
    ("model.classification_heads.mnli.out_proj.weight", "classification_head.out_proj.weight"),
    ("model.classification_heads.mnli.out_proj.bias", "classification_head.out_proj.bias"),
]


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def load_xsum_checkpoint(checkpoint_path):
    """Checkpoint path should end in model.pt"""
    sd = torch.load(checkpoint_path, map_location="cpu")
    hub_interface = torch.hub.load("pytorch/fairseq", "bart.large.cnn").eval()
    hub_interface.model.load_state_dict(sd["model"])
    return hub_interface


def make_linear_from_emb(emb):
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer


@torch.no_grad()
def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkpoint_name=None):
    """
    Copy/paste/tweak model's weights to our BERT structure.
    """
    if not os.path.exists(checkpoint_path):
        bart = torch.hub.load("pytorch/fairseq", checkpoint_path).eval()
    else:
        bart = load_xsum_checkpoint(checkpoint_path)

    bart.model.upgrade_state_dict(bart.model.state_dict())
    if hf_checkpoint_name is None:
        hf_checkpoint_name = checkpoint_path.replace(".", "-")
    config = BartConfig.from_pretrained(hf_checkpoint_name)
    tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
    tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
    if not torch.eq(tokens, tokens2).all():
        raise ValueError(
            f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
        )

    if checkpoint_path == "bart.large.mnli":
        state_dict = bart.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["model.shared.weight"] = state_dict["model.decoder.embed_tokens.weight"]
        for src, dest in mnli_rename_keys:
            rename_key(state_dict, src, dest)
        model = BartForSequenceClassification(config).eval()
        model.load_state_dict(state_dict)
        fairseq_output = bart.predict("mnli", tokens, return_logits=True)
        new_model_outputs = model(tokens)[0]  # logits
    else:  # no classification heads to worry about
        state_dict = bart.model.state_dict()
        remove_ignore_keys_(state_dict)
        state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
        fairseq_output = bart.extract_features(tokens)
        if hf_checkpoint_name == "facebook/bart-large":
            model = BartModel(config).eval()
            model.load_state_dict(state_dict)
            new_model_outputs = model(tokens).model[0]
        else:
            model = BartForConditionalGeneration(config).eval()  # an existing summarization ckpt
            model.model.load_state_dict(state_dict)
            if hasattr(model, "lm_head"):
                model.lm_head = make_linear_from_emb(model.model.shared)
            new_model_outputs = model.model(tokens)[0]

    # Check results
    if fairseq_output.shape != new_model_outputs.shape:
        raise ValueError(
            f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
        )
    if (fairseq_output != new_model_outputs).any().item():
        raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "fairseq_path", type=str, help="bart.large, bart.large.cnn or a path to a model.pt on local filesystem."
    )
    parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--hf_config", default=None, type=str, help="Which huggingface architecture to use: bart-large-xsum"
    )
    args = parser.parse_args()
    convert_bart_checkpoint(args.fairseq_path, args.pytorch_dump_folder_path, hf_checkpoint_name=args.hf_config)
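# Hedged usage sketch (not in the original file): converts the fairseq "bart.large.cnn" weights and
# compares outputs against the fairseq model before saving. The dump folder and the --hf_config value
# are hypothetical examples.
#
# convert_bart_checkpoint("bart.large.cnn", "/tmp/bart-large-cnn", hf_checkpoint_name="facebook/bart-large-cnn")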
@@ -1,373 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BEiT checkpoints from the unilm repository."""

import argparse
import json
from pathlib import Path

import requests
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import (
    BeitConfig,
    BeitForImageClassification,
    BeitForMaskedImageModeling,
    BeitForSemanticSegmentation,
    BeitImageProcessor,
)
from transformers.image_utils import PILImageResampling
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
    prefix = "backbone." if is_semantic else ""

    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
        )
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
        )
        rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            (f"{prefix}cls_token", "beit.embeddings.cls_token"),
            (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
            (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
        ]
    )

    if has_lm_head:
        # mask token + shared relative position bias + layernorm
        rename_keys.extend(
            [
                ("mask_token", "beit.embeddings.mask_token"),
                (
                    "rel_pos_bias.relative_position_bias_table",
                    "beit.encoder.relative_position_bias.relative_position_bias_table",
                ),
                (
                    "rel_pos_bias.relative_position_index",
                    "beit.encoder.relative_position_bias.relative_position_index",
                ),
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )
    elif is_semantic:
        # semantic segmentation classification heads
        rename_keys.extend(
            [
                ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
                ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
                ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
                ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
            ]
        )
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("fc_norm.weight", "beit.pooler.layernorm.weight"),
                ("fc_norm.bias", "beit.pooler.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys


# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
    for i in range(config.num_hidden_layers):
        prefix = "backbone." if is_semantic else ""
        # queries, keys and values
        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
 | 
			
		||||
        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
 | 
			
		||||
        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
 | 
			
		||||
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
 | 
			
		||||
            : config.hidden_size, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
 | 
			
		||||
            config.hidden_size : config.hidden_size * 2, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
 | 
			
		||||
            -config.hidden_size :, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
 | 
			
		||||
 | 
			
		||||
        # gamma_1 and gamma_2
 | 
			
		||||
        # we call them lambda because otherwise they are renamed when using .from_pretrained
 | 
			
		||||
        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
 | 
			
		||||
        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
 | 
			
		||||
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
 | 
			
		||||
        state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
 | 
			
		||||
 | 
			
		||||
        # relative_position bias table + index
 | 
			
		||||
        if not has_lm_head:
 | 
			
		||||
            # each layer has its own relative position bias
 | 
			
		||||
            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
 | 
			
		||||
            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
 | 
			
		||||
 | 
			
		||||
            state_dict[
 | 
			
		||||
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
 | 
			
		||||
            ] = table
 | 
			
		||||
            state_dict[
 | 
			
		||||
                f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
 | 
			
		||||
            ] = index
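# Illustrative sketch, not part of the original conversion: the fused qkv projection above
# stacks query, key and value along dim 0, so a (3 * hidden_size, hidden_size) matrix is
# sliced into three (hidden_size, hidden_size) blocks exactly as done in read_in_q_k_v.
# The helper below is hypothetical and never called by this script.
def _demo_split_qkv(qkv_weight, hidden_size):
    """Split a fused (3 * hidden_size, hidden_size) projection into query/key/value blocks."""
    query = qkv_weight[:hidden_size, :]
    key = qkv_weight[hidden_size : hidden_size * 2, :]
    value = qkv_weight[-hidden_size:, :]
    return query, key, value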
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(dct, old, new):
 | 
			
		||||
    val = dct.pop(old)
 | 
			
		||||
    dct[new] = val
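# Illustrative example (hypothetical state dict, not read from a real checkpoint): applying
# one of the (old, new) pairs produced by create_rename_keys above.
#   sd = {"blocks.0.norm1.weight": torch.zeros(768)}
#   rename_key(sd, "blocks.0.norm1.weight", "beit.encoder.layer.0.layernorm_before.weight")
#   # sd now only contains the new key, still pointing at the same tensor.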
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our BEiT structure.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # define default BEiT configuration
 | 
			
		||||
    config = BeitConfig()
 | 
			
		||||
    has_lm_head = False
 | 
			
		||||
    is_semantic = False
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    # set config parameters based on URL
 | 
			
		||||
    if checkpoint_url[-9:-4] == "pt22k":
 | 
			
		||||
        # masked image modeling
 | 
			
		||||
        config.use_shared_relative_position_bias = True
 | 
			
		||||
        config.use_mask_token = True
 | 
			
		||||
        has_lm_head = True
 | 
			
		||||
    elif checkpoint_url[-9:-4] == "ft22k":
 | 
			
		||||
        # intermediate fine-tuning on ImageNet-22k
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 21841
 | 
			
		||||
        filename = "imagenet-22k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        # this dataset contains 21843 labels but the model only has 21841
 | 
			
		||||
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
 | 
			
		||||
        del id2label[9205]
 | 
			
		||||
        del id2label[15027]
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    elif checkpoint_url[-8:-4] == "to1k":
 | 
			
		||||
        # fine-tuning on ImageNet-1k
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 1000
 | 
			
		||||
        filename = "imagenet-1k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
        if "384" in checkpoint_url:
 | 
			
		||||
            config.image_size = 384
 | 
			
		||||
        if "512" in checkpoint_url:
 | 
			
		||||
            config.image_size = 512
 | 
			
		||||
    elif "ade20k" in checkpoint_url:
 | 
			
		||||
        # fine-tuning
 | 
			
		||||
        config.use_relative_position_bias = True
 | 
			
		||||
        config.num_labels = 150
 | 
			
		||||
        filename = "ade20k-id2label.json"
 | 
			
		||||
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
        id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
        config.id2label = id2label
 | 
			
		||||
        config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
        config.image_size = 640
 | 
			
		||||
        is_semantic = True
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
 | 
			
		||||
 | 
			
		||||
    # size of the architecture
 | 
			
		||||
    if "base" in checkpoint_url:
 | 
			
		||||
        pass
 | 
			
		||||
    elif "large" in checkpoint_url:
 | 
			
		||||
        config.hidden_size = 1024
 | 
			
		||||
        config.intermediate_size = 4096
 | 
			
		||||
        config.num_hidden_layers = 24
 | 
			
		||||
        config.num_attention_heads = 16
 | 
			
		||||
        if "ade20k" in checkpoint_url:
 | 
			
		||||
            config.image_size = 640
 | 
			
		||||
            config.out_indices = [7, 11, 15, 23]
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
 | 
			
		||||
 | 
			
		||||
    # load state_dict of original model, remove and rename some keys
 | 
			
		||||
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
 | 
			
		||||
    state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
 | 
			
		||||
 | 
			
		||||
    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
 | 
			
		||||
    for src, dest in rename_keys:
 | 
			
		||||
        rename_key(state_dict, src, dest)
 | 
			
		||||
    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
 | 
			
		||||
    if is_semantic:
 | 
			
		||||
        # add prefix to decoder keys
 | 
			
		||||
        for key, val in state_dict.copy().items():
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            if key.startswith("backbone.fpn"):
 | 
			
		||||
                key = key.replace("backbone.fpn", "fpn")
 | 
			
		||||
            state_dict[key] = val
 | 
			
		||||
 | 
			
		||||
    # load HuggingFace model
 | 
			
		||||
    if checkpoint_url[-9:-4] == "pt22k":
 | 
			
		||||
        model = BeitForMaskedImageModeling(config)
 | 
			
		||||
    elif "ade20k" in checkpoint_url:
 | 
			
		||||
        model = BeitForSemanticSegmentation(config)
 | 
			
		||||
    else:
 | 
			
		||||
        model = BeitForImageClassification(config)
 | 
			
		||||
    model.eval()
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
 | 
			
		||||
    # Check outputs on an image
 | 
			
		||||
    if is_semantic:
 | 
			
		||||
        image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
 | 
			
		||||
        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
 | 
			
		||||
        image = Image.open(ds[0]["file"])
 | 
			
		||||
    else:
 | 
			
		||||
        image_processor = BeitImageProcessor(
 | 
			
		||||
            size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
 | 
			
		||||
        )
 | 
			
		||||
        image = prepare_img()
 | 
			
		||||
 | 
			
		||||
    encoding = image_processor(images=image, return_tensors="pt")
 | 
			
		||||
    pixel_values = encoding["pixel_values"]
 | 
			
		||||
 | 
			
		||||
    outputs = model(pixel_values)
 | 
			
		||||
    logits = outputs.logits
 | 
			
		||||
 | 
			
		||||
    # verify logits
 | 
			
		||||
    expected_shape = torch.Size([1, 1000])
 | 
			
		||||
    if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 196, 8192])
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 196, 8192])
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 21841])
 | 
			
		||||
        expected_logits = torch.tensor([2.2288, 2.4671, 0.7395])
 | 
			
		||||
        expected_class_idx = 2397
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"):
 | 
			
		||||
        expected_shape = torch.Size([1, 21841])
 | 
			
		||||
        expected_logits = torch.tensor([1.6881, -0.2787, 0.5901])
 | 
			
		||||
        expected_class_idx = 2396
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"):
 | 
			
		||||
        expected_logits = torch.tensor([0.1241, 0.0798, -0.6569])
 | 
			
		||||
        expected_class_idx = 285
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108])
 | 
			
		||||
        expected_class_idx = 281
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"):
 | 
			
		||||
        expected_logits = torch.tensor([0.4610, -0.0928, 0.2086])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
 | 
			
		||||
        expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
 | 
			
		||||
        expected_class_idx = 761
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
 | 
			
		||||
        expected_shape = (1, 150, 160, 160)
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
 | 
			
		||||
                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
 | 
			
		||||
                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
 | 
			
		||||
        expected_shape = (1, 150, 160, 160)
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
 | 
			
		||||
                [[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
 | 
			
		||||
                [[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Can't verify logits as model is not supported")
 | 
			
		||||
 | 
			
		||||
    if logits.shape != expected_shape:
 | 
			
		||||
        raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
 | 
			
		||||
    if not has_lm_head:
 | 
			
		||||
        if is_semantic:
 | 
			
		||||
            if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
 | 
			
		||||
                raise ValueError("First elements of logits not as expected")
 | 
			
		||||
        else:
 | 
			
		||||
            print("Predicted class idx:", logits.argmax(-1).item())
 | 
			
		||||
 | 
			
		||||
            if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
 | 
			
		||||
                raise ValueError("First elements of logits not as expected")
 | 
			
		||||
            if logits.argmax(-1).item() != expected_class_idx:
 | 
			
		||||
                raise ValueError("Predicted class index not as expected")
 | 
			
		||||
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    print(f"Saving model to {pytorch_dump_folder_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    print(f"Saving image processor to {pytorch_dump_folder_path}")
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_url",
 | 
			
		||||
        default="https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="URL to the original PyTorch checkpoint (.pth file).",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
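    # Example invocation (illustrative; the output folder is hypothetical and the URL is the
    # default listed above):
    #   python <this_script>.py \
    #       --checkpoint_url https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth \
    #       --pytorch_dump_folder_path ./beit-base-patch16-224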
 | 
			
		||||
@ -1,246 +0,0 @@
 | 
			
		||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This script can be used to convert a head-less TF2.x Bert model to PyTorch, as published on the official (now
 | 
			
		||||
deprecated) GitHub: https://github.com/tensorflow/models/tree/v2.3.0/official/nlp/bert
 | 
			
		||||
 | 
			
		||||
TF2.x uses different variable names from the original BERT (TF 1.4) implementation. The script re-maps the TF2.x Bert
 | 
			
		||||
weight names to the original names, so the model can be imported with Hugging Face Transformers.
 | 
			
		||||
 | 
			
		||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
 | 
			
		||||
 | 
			
		||||
Note: This script only works with an older version of the TensorFlow models repository (<= v2.3.0).
 | 
			
		||||
      Models trained with newer versions are not compatible with this script.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertModel
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
 | 
			
		||||
    tf_path = os.path.abspath(tf_checkpoint_path)
 | 
			
		||||
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
 | 
			
		||||
    # Load weights from TF model
 | 
			
		||||
    init_vars = tf.train.list_variables(tf_path)
 | 
			
		||||
    names = []
 | 
			
		||||
    arrays = []
 | 
			
		||||
    layer_depth = []
 | 
			
		||||
    for full_name, shape in init_vars:
 | 
			
		||||
        # logger.info(f"Loading TF weight {name} with shape {shape}")
 | 
			
		||||
        name = full_name.split("/")
 | 
			
		||||
        if full_name == "_CHECKPOINTABLE_OBJECT_GRAPH" or name[0] in ["global_step", "save_counter"]:
 | 
			
		||||
            logger.info(f"Skipping non-model layer {full_name}")
 | 
			
		||||
            continue
 | 
			
		||||
        if "optimizer" in full_name:
 | 
			
		||||
            logger.info(f"Skipping optimization layer {full_name}")
 | 
			
		||||
            continue
 | 
			
		||||
        if name[0] == "model":
 | 
			
		||||
            # ignore initial 'model'
 | 
			
		||||
            name = name[1:]
 | 
			
		||||
        # figure out how many levels deep the name is
 | 
			
		||||
        depth = 0
 | 
			
		||||
        for _name in name:
 | 
			
		||||
            if _name.startswith("layer_with_weights"):
 | 
			
		||||
                depth += 1
 | 
			
		||||
            else:
 | 
			
		||||
                break
 | 
			
		||||
        layer_depth.append(depth)
 | 
			
		||||
        # read data
 | 
			
		||||
        array = tf.train.load_variable(tf_path, full_name)
 | 
			
		||||
        names.append("/".join(name))
 | 
			
		||||
        arrays.append(array)
 | 
			
		||||
    logger.info(f"Read a total of {len(arrays):,} layers")
 | 
			
		||||
 | 
			
		||||
    # Sanity check
 | 
			
		||||
    if len(set(layer_depth)) != 1:
 | 
			
		||||
        raise ValueError(f"Found layer names with different depths (layer depth {list(set(layer_depth))})")
 | 
			
		||||
    layer_depth = list(set(layer_depth))[0]
 | 
			
		||||
    if layer_depth != 1:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            "The model contains more than just the embedding/encoder layers. This script does not handle MLM/NSP"
 | 
			
		||||
            " heads."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # convert layers
 | 
			
		||||
    logger.info("Converting weights...")
 | 
			
		||||
    for full_name, array in zip(names, arrays):
 | 
			
		||||
        name = full_name.split("/")
 | 
			
		||||
        pointer = model
 | 
			
		||||
        trace = []
 | 
			
		||||
        for i, m_name in enumerate(name):
 | 
			
		||||
            if m_name == ".ATTRIBUTES":
 | 
			
		||||
                # variable names end with .ATTRIBUTES/VARIABLE_VALUE
 | 
			
		||||
                break
 | 
			
		||||
            if m_name.startswith("layer_with_weights"):
 | 
			
		||||
                layer_num = int(m_name.split("-")[-1])
 | 
			
		||||
                if layer_num <= 2:
 | 
			
		||||
                    # embedding layers
 | 
			
		||||
                    # layer_num 0: word_embeddings
 | 
			
		||||
                    # layer_num 1: position_embeddings
 | 
			
		||||
                    # layer_num 2: token_type_embeddings
 | 
			
		||||
                    continue
 | 
			
		||||
                elif layer_num == 3:
 | 
			
		||||
                    # embedding LayerNorm
 | 
			
		||||
                    trace.extend(["embeddings", "LayerNorm"])
 | 
			
		||||
                    pointer = getattr(pointer, "embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
                elif layer_num > 3 and layer_num < config.num_hidden_layers + 4:
 | 
			
		||||
                    # encoder layers
 | 
			
		||||
                    trace.extend(["encoder", "layer", str(layer_num - 4)])
 | 
			
		||||
                    pointer = getattr(pointer, "encoder")
 | 
			
		||||
                    pointer = getattr(pointer, "layer")
 | 
			
		||||
                    pointer = pointer[layer_num - 4]
 | 
			
		||||
                elif layer_num == config.num_hidden_layers + 4:
 | 
			
		||||
                    # pooler layer
 | 
			
		||||
                    trace.extend(["pooler", "dense"])
 | 
			
		||||
                    pointer = getattr(pointer, "pooler")
 | 
			
		||||
                    pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "embeddings":
 | 
			
		||||
                trace.append("embeddings")
 | 
			
		||||
                pointer = getattr(pointer, "embeddings")
 | 
			
		||||
                if layer_num == 0:
 | 
			
		||||
                    trace.append("word_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "word_embeddings")
 | 
			
		||||
                elif layer_num == 1:
 | 
			
		||||
                    trace.append("position_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "position_embeddings")
 | 
			
		||||
                elif layer_num == 2:
 | 
			
		||||
                    trace.append("token_type_embeddings")
 | 
			
		||||
                    pointer = getattr(pointer, "token_type_embeddings")
 | 
			
		||||
                else:
 | 
			
		||||
                    raise ValueError(f"Unknown embedding layer with name {full_name}")
 | 
			
		||||
                trace.append("weight")
 | 
			
		||||
                pointer = getattr(pointer, "weight")
 | 
			
		||||
            elif m_name == "_attention_layer":
 | 
			
		||||
                # self-attention layer
 | 
			
		||||
                trace.extend(["attention", "self"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "self")
 | 
			
		||||
            elif m_name == "_attention_layer_norm":
 | 
			
		||||
                # output attention norm
 | 
			
		||||
                trace.extend(["attention", "output", "LayerNorm"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
            elif m_name == "_attention_output_dense":
 | 
			
		||||
                # output attention dense
 | 
			
		||||
                trace.extend(["attention", "output", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "attention")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_dense":
 | 
			
		||||
                # output dense
 | 
			
		||||
                trace.extend(["output", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_layer_norm":
 | 
			
		||||
                # output LayerNorm
 | 
			
		||||
                trace.extend(["output", "LayerNorm"])
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
                pointer = getattr(pointer, "LayerNorm")
 | 
			
		||||
            elif m_name == "_key_dense":
 | 
			
		||||
                # attention key
 | 
			
		||||
                trace.append("key")
 | 
			
		||||
                pointer = getattr(pointer, "key")
 | 
			
		||||
            elif m_name == "_query_dense":
 | 
			
		||||
                # attention query
 | 
			
		||||
                trace.append("query")
 | 
			
		||||
                pointer = getattr(pointer, "query")
 | 
			
		||||
            elif m_name == "_value_dense":
 | 
			
		||||
                # attention value
 | 
			
		||||
                trace.append("value")
 | 
			
		||||
                pointer = getattr(pointer, "value")
 | 
			
		||||
            elif m_name == "_intermediate_dense":
 | 
			
		||||
                # intermediate dense
 | 
			
		||||
                trace.extend(["intermediate", "dense"])
 | 
			
		||||
                pointer = getattr(pointer, "intermediate")
 | 
			
		||||
                pointer = getattr(pointer, "dense")
 | 
			
		||||
            elif m_name == "_output_layer_norm":
 | 
			
		||||
                # output layer norm
 | 
			
		||||
                trace.append("output")
 | 
			
		||||
                pointer = getattr(pointer, "output")
 | 
			
		||||
            # weights & biases
 | 
			
		||||
            elif m_name in ["bias", "beta"]:
 | 
			
		||||
                trace.append("bias")
 | 
			
		||||
                pointer = getattr(pointer, "bias")
 | 
			
		||||
            elif m_name in ["kernel", "gamma"]:
 | 
			
		||||
                trace.append("weight")
 | 
			
		||||
                pointer = getattr(pointer, "weight")
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning(f"Ignored {m_name}")
 | 
			
		||||
        # for certain layers reshape is necessary
 | 
			
		||||
        trace = ".".join(trace)
 | 
			
		||||
        if re.match(r"(\S+)\.attention\.self\.(key|value|query)\.(bias|weight)", trace) or re.match(
 | 
			
		||||
            r"(\S+)\.attention\.output\.dense\.weight", trace
 | 
			
		||||
        ):
 | 
			
		||||
            array = array.reshape(pointer.data.shape)
 | 
			
		||||
        if "kernel" in full_name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
        if pointer.shape == array.shape:
 | 
			
		||||
            pointer.data = torch.from_numpy(array)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Shape mismatch in layer {full_name}: Model expects shape {pointer.shape} but layer contains shape:"
 | 
			
		||||
                f" {array.shape}"
 | 
			
		||||
            )
 | 
			
		||||
        logger.info(f"Successfully set variable {full_name} to PyTorch layer {trace}")
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf2_checkpoint_to_pytorch(tf_checkpoint_path, config_path, pytorch_dump_path):
 | 
			
		||||
    # Instantiate model
 | 
			
		||||
    logger.info(f"Loading model based on config from {config_path}...")
 | 
			
		||||
    config = BertConfig.from_json_file(config_path)
 | 
			
		||||
    model = BertModel(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from checkpoint
 | 
			
		||||
    logger.info(f"Loading weights from checkpoint {tf_checkpoint_path}...")
 | 
			
		||||
    load_tf2_weights_in_bert(model, tf_checkpoint_path, config)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    logger.info(f"Saving PyTorch model to {pytorch_dump_path}...")
 | 
			
		||||
    torch.save(model.state_dict(), pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow 2.x checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the output PyTorch model (must include filename).",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf2_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
 | 
			
		||||
@ -1,62 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2018 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert BERT checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
 | 
			
		||||
    # Initialise PyTorch model
 | 
			
		||||
    config = BertConfig.from_json_file(bert_config_file)
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
    model = BertForPreTraining(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    torch.save(model.state_dict(), pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "The config json file corresponding to the pre-trained BERT model. \n"
 | 
			
		||||
            "This specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
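    # Example invocation (illustrative paths):
    #   python <this_script>.py --tf_checkpoint_path ./bert_model.ckpt \
    #       --bert_config_file ./bert_config.json --pytorch_dump_path ./pytorch_model.bin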
 | 
			
		||||
@ -1,112 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2018 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
 | 
			
		||||
    """
 | 
			
		||||
    Args:
 | 
			
		||||
        model: BertModel Pytorch model instance to be converted
 | 
			
		||||
        ckpt_dir: Tensorflow model directory
 | 
			
		||||
        model_name: model name
 | 
			
		||||
 | 
			
		||||
    Currently supported HF models:
 | 
			
		||||
 | 
			
		||||
        - Y BertModel
 | 
			
		||||
        - N BertForMaskedLM
 | 
			
		||||
        - N BertForPreTraining
 | 
			
		||||
        - N BertForMultipleChoice
 | 
			
		||||
        - N BertForNextSentencePrediction
 | 
			
		||||
        - N BertForSequenceClassification
 | 
			
		||||
        - N BertForQuestionAnswering
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    tensors_to_transpose = ("dense.weight", "attention.self.query", "attention.self.key", "attention.self.value")
 | 
			
		||||
 | 
			
		||||
    var_map = (
 | 
			
		||||
        ("layer.", "layer_"),
 | 
			
		||||
        ("word_embeddings.weight", "word_embeddings"),
 | 
			
		||||
        ("position_embeddings.weight", "position_embeddings"),
 | 
			
		||||
        ("token_type_embeddings.weight", "token_type_embeddings"),
 | 
			
		||||
        (".", "/"),
 | 
			
		||||
        ("LayerNorm/weight", "LayerNorm/gamma"),
 | 
			
		||||
        ("LayerNorm/bias", "LayerNorm/beta"),
 | 
			
		||||
        ("weight", "kernel"),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if not os.path.isdir(ckpt_dir):
 | 
			
		||||
        os.makedirs(ckpt_dir)
 | 
			
		||||
 | 
			
		||||
    state_dict = model.state_dict()
 | 
			
		||||
 | 
			
		||||
    def to_tf_var_name(name: str):
 | 
			
		||||
        for patt, repl in iter(var_map):
 | 
			
		||||
            name = name.replace(patt, repl)
 | 
			
		||||
        return f"bert/{name}"
 | 
			
		||||
 | 
			
		||||
    def create_tf_var(tensor: np.ndarray, name: str, session: tf.Session):
 | 
			
		||||
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
 | 
			
		||||
        tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer())
 | 
			
		||||
        session.run(tf.variables_initializer([tf_var]))
 | 
			
		||||
        session.run(tf_var)
 | 
			
		||||
        return tf_var
 | 
			
		||||
 | 
			
		||||
    tf.reset_default_graph()
 | 
			
		||||
    with tf.Session() as session:
 | 
			
		||||
        for var_name in state_dict:
 | 
			
		||||
            tf_name = to_tf_var_name(var_name)
 | 
			
		||||
            torch_tensor = state_dict[var_name].numpy()
 | 
			
		||||
            if any(x in var_name for x in tensors_to_transpose):
 | 
			
		||||
                torch_tensor = torch_tensor.T
 | 
			
		||||
            tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session)
 | 
			
		||||
            tf_var.assign(tf.cast(torch_tensor, tf_var.dtype))
 | 
			
		||||
            tf_weight = session.run(tf_var)
 | 
			
		||||
            print(f"Successfully created {tf_name}: {np.allclose(tf_weight, torch_tensor)}")
 | 
			
		||||
 | 
			
		||||
        saver = tf.train.Saver(tf.trainable_variables())
 | 
			
		||||
        saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(raw_args=None):
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--model_name", type=str, required=True, help="model name e.g. google-bert/bert-base-uncased")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--cache_dir", type=str, default=None, required=False, help="Directory containing pytorch model"
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--pytorch_model_path", type=str, required=True, help="/path/to/<pytorch-model-name>.bin")
 | 
			
		||||
    parser.add_argument("--tf_cache_dir", type=str, required=True, help="Directory in which to save tensorflow model")
 | 
			
		||||
    args = parser.parse_args(raw_args)
 | 
			
		||||
 | 
			
		||||
    model = BertModel.from_pretrained(
 | 
			
		||||
        pretrained_model_name_or_path=args.model_name,
 | 
			
		||||
        state_dict=torch.load(args.pytorch_model_path),
 | 
			
		||||
        cache_dir=args.cache_dir,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_cache_dir, model_name=args.model_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
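    # Example invocation (illustrative paths; the model name is the one from the help text above):
    #   python <this_script>.py --model_name google-bert/bert-base-uncased \
    #       --pytorch_model_path ./pytorch_model.bin --tf_cache_dir ./tf_checkpoints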
 | 
			
		||||
@ -1,188 +0,0 @@
 | 
			
		||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This script converts a lm-head checkpoint from the "Token Dropping" implementation into a PyTorch-compatible BERT
 | 
			
		||||
model. The official implementation of "Token Dropping" can be found in the TensorFlow Models repository:
 | 
			
		||||
 | 
			
		||||
https://github.com/tensorflow/models/tree/master/official/projects/token_dropping
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import BertConfig, BertForMaskedLM
 | 
			
		||||
from transformers.models.bert.modeling_bert import (
 | 
			
		||||
    BertIntermediate,
 | 
			
		||||
    BertLayer,
 | 
			
		||||
    BertOutput,
 | 
			
		||||
    BertPooler,
 | 
			
		||||
    BertSelfAttention,
 | 
			
		||||
    BertSelfOutput,
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_checkpoint_to_pytorch(tf_checkpoint_path: str, config_path: str, pytorch_dump_path: str):
 | 
			
		||||
    def get_masked_lm_array(name: str):
 | 
			
		||||
        full_name = f"masked_lm/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_array(name: str):
 | 
			
		||||
        full_name = f"encoder/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_layer_array(layer_index: int, name: str):
 | 
			
		||||
        full_name = f"encoder/_transformer_layers/{layer_index}/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
 | 
			
		||||
 | 
			
		||||
    def get_encoder_attention_layer_array(layer_index: int, name: str, original_shape):
 | 
			
		||||
        full_name = f"encoder/_transformer_layers/{layer_index}/_attention_layer/{name}/.ATTRIBUTES/VARIABLE_VALUE"
 | 
			
		||||
        array = tf.train.load_variable(tf_checkpoint_path, full_name)
 | 
			
		||||
        array = array.reshape(original_shape)
 | 
			
		||||
 | 
			
		||||
        if "kernel" in name:
 | 
			
		||||
            array = array.transpose()
 | 
			
		||||
 | 
			
		||||
        return torch.from_numpy(array)
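    # Note on the reshape/transpose above (illustrative shapes): the TF checkpoint keeps
    # attention projections split per head (for example a query kernel of shape
    # (hidden_size, num_heads, head_dim)), while the corresponding PyTorch nn.Linear weight
    # is a flat (out_features, in_features) matrix, so the array is first reshaped to the
    # PyTorch shape and then transposed for "kernel" variables.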
 | 
			
		||||
 | 
			
		||||
    print(f"Loading model based on config from {config_path}...")
 | 
			
		||||
    config = BertConfig.from_json_file(config_path)
 | 
			
		||||
    model = BertForMaskedLM(config)
 | 
			
		||||
 | 
			
		||||
    # Layers
 | 
			
		||||
    for layer_index in range(0, config.num_hidden_layers):
 | 
			
		||||
        layer: BertLayer = model.bert.encoder.layer[layer_index]
 | 
			
		||||
 | 
			
		||||
        # Self-attention
 | 
			
		||||
        self_attn: BertSelfAttention = layer.attention.self
 | 
			
		||||
 | 
			
		||||
        self_attn.query.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_query_dense/kernel", self_attn.query.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.query.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_query_dense/bias", self_attn.query.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.key.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_key_dense/kernel", self_attn.key.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.key.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_key_dense/bias", self_attn.key.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.value.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_value_dense/kernel", self_attn.value.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_attn.value.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_value_dense/bias", self_attn.value.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Self-attention Output
 | 
			
		||||
        self_output: BertSelfOutput = layer.attention.output
 | 
			
		||||
 | 
			
		||||
        self_output.dense.weight.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_output_dense/kernel", self_output.dense.weight.data.shape
 | 
			
		||||
        )
 | 
			
		||||
        self_output.dense.bias.data = get_encoder_attention_layer_array(
 | 
			
		||||
            layer_index, "_output_dense/bias", self_output.dense.bias.data.shape
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/gamma")
 | 
			
		||||
        self_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_attention_layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
        # Intermediate
 | 
			
		||||
        intermediate: BertIntermediate = layer.intermediate
 | 
			
		||||
 | 
			
		||||
        intermediate.dense.weight.data = get_encoder_layer_array(layer_index, "_intermediate_dense/kernel")
 | 
			
		||||
        intermediate.dense.bias.data = get_encoder_layer_array(layer_index, "_intermediate_dense/bias")
 | 
			
		||||
 | 
			
		||||
        # Output
 | 
			
		||||
        bert_output: BertOutput = layer.output
 | 
			
		||||
 | 
			
		||||
        bert_output.dense.weight.data = get_encoder_layer_array(layer_index, "_output_dense/kernel")
 | 
			
		||||
        bert_output.dense.bias.data = get_encoder_layer_array(layer_index, "_output_dense/bias")
 | 
			
		||||
 | 
			
		||||
        bert_output.LayerNorm.weight.data = get_encoder_layer_array(layer_index, "_output_layer_norm/gamma")
 | 
			
		||||
        bert_output.LayerNorm.bias.data = get_encoder_layer_array(layer_index, "_output_layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
    # Embeddings
 | 
			
		||||
    model.bert.embeddings.position_embeddings.weight.data = get_encoder_array("_position_embedding_layer/embeddings")
 | 
			
		||||
    model.bert.embeddings.token_type_embeddings.weight.data = get_encoder_array("_type_embedding_layer/embeddings")
 | 
			
		||||
    model.bert.embeddings.LayerNorm.weight.data = get_encoder_array("_embedding_norm_layer/gamma")
 | 
			
		||||
    model.bert.embeddings.LayerNorm.bias.data = get_encoder_array("_embedding_norm_layer/beta")
 | 
			
		||||
 | 
			
		||||
    # LM Head
 | 
			
		||||
    lm_head = model.cls.predictions.transform
 | 
			
		||||
 | 
			
		||||
    lm_head.dense.weight.data = get_masked_lm_array("dense/kernel")
 | 
			
		||||
    lm_head.dense.bias.data = get_masked_lm_array("dense/bias")
 | 
			
		||||
 | 
			
		||||
    lm_head.LayerNorm.weight.data = get_masked_lm_array("layer_norm/gamma")
 | 
			
		||||
    lm_head.LayerNorm.bias.data = get_masked_lm_array("layer_norm/beta")
 | 
			
		||||
 | 
			
		||||
    model.bert.embeddings.word_embeddings.weight.data = get_masked_lm_array("embedding_table")
 | 
			
		||||
 | 
			
		||||
    # Pooling
 | 
			
		||||
    model.bert.pooler = BertPooler(config=config)
 | 
			
		||||
    model.bert.pooler.dense.weight.data = get_encoder_array("_pooler_layer/kernel")
 | 
			
		||||
    model.bert.pooler.dense.bias.data = get_encoder_array("_pooler_layer/bias")
 | 
			
		||||
 | 
			
		||||
    # Export final model
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
    # Integration test - should load without any errors ;)
 | 
			
		||||
    new_model = BertForMaskedLM.from_pretrained(pytorch_dump_path)
 | 
			
		||||
    print(new_model.eval())
 | 
			
		||||
 | 
			
		||||
    print("Model conversion was done sucessfully!")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", type=str, required=True, help="Path to the TensorFlow Token Dropping checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--bert_config_file",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="The config json file corresponding to the BERT model. This specifies the model architecture.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path",
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the output PyTorch model.",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
 | 
			
		||||
@ -1,69 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert BigBird checkpoint."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
from transformers import BigBirdConfig, BigBirdForPreTraining, BigBirdForQuestionAnswering, load_tf_weights_in_big_bird
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, big_bird_config_file, pytorch_dump_path, is_trivia_qa):
 | 
			
		||||
    # Initialise PyTorch model
 | 
			
		||||
    config = BigBirdConfig.from_json_file(big_bird_config_file)
 | 
			
		||||
    print(f"Building PyTorch model from configuration: {config}")
 | 
			
		||||
 | 
			
		||||
    if is_trivia_qa:
 | 
			
		||||
        model = BigBirdForQuestionAnswering(config)
 | 
			
		||||
    else:
 | 
			
		||||
        model = BigBirdForPreTraining(config)
 | 
			
		||||
 | 
			
		||||
    # Load weights from tf checkpoint
 | 
			
		||||
    load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=is_trivia_qa)
 | 
			
		||||
 | 
			
		||||
    # Save pytorch-model
 | 
			
		||||
    print(f"Save PyTorch model to {pytorch_dump_path}")
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--big_bird_config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help=(
 | 
			
		||||
            "The config json file corresponding to the pre-trained BERT model. \n"
 | 
			
		||||
            "This specifies the model architecture."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--is_trivia_qa", action="store_true", help="Whether to convert a model with a trivia_qa head."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_checkpoint_to_pytorch(
 | 
			
		||||
        args.tf_checkpoint_path, args.big_bird_config_file, args.pytorch_dump_path, args.is_trivia_qa
 | 
			
		||||
    )
 | 
			
		||||
@@ -1,170 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from typing import Dict

import tensorflow as tf
import torch
from tqdm import tqdm

from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration


INIT_COMMON = [
    # tf -> hf
    ("/", "."),
    ("layer_", "layers."),
    ("kernel", "weight"),
    ("beta", "bias"),
    ("gamma", "weight"),
    ("pegasus", "model"),
]
END_COMMON = [
    (".output.dense", ".fc2"),
    ("intermediate.LayerNorm", "final_layer_norm"),
    ("intermediate.dense", "fc1"),
]

DECODER_PATTERNS = (
    INIT_COMMON
    + [
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
        ("attention.output.dense", "self_attn.out_proj"),
        ("attention.self", "self_attn"),
        ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"),
        ("attention.encdec_output.dense", "encoder_attn.out_proj"),
        ("attention.encdec", "encoder_attn"),
        ("key", "k_proj"),
        ("value", "v_proj"),
        ("query", "q_proj"),
        ("decoder.LayerNorm", "decoder.layernorm_embedding"),
    ]
    + END_COMMON
)

REMAINING_PATTERNS = (
    INIT_COMMON
    + [
        ("embeddings.word_embeddings", "shared.weight"),
        ("embeddings.position_embeddings", "embed_positions.weight"),
        ("attention.self.LayerNorm", "self_attn_layer_norm"),
        ("attention.output.dense", "self_attn.output"),
        ("attention.self", "self_attn.self"),
        ("encoder.LayerNorm", "encoder.layernorm_embedding"),
    ]
    + END_COMMON
)

KEYS_TO_IGNORE = [
    "encdec/key/bias",
    "encdec/query/bias",
    "encdec/value/bias",
    "self/key/bias",
    "self/query/bias",
    "self/value/bias",
    "encdec_output/dense/bias",
    "attention/output/dense/bias",
]


def rename_state_dict_key(k, patterns):
    for tf_name, hf_name in patterns:
        k = k.replace(tf_name, hf_name)
    return k
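

# Illustrative example, added for clarity and not part of the original script: a hypothetical
# TF checkpoint key "pegasus/decoder/layer_0/attention/self/query/kernel" would be mapped by
# rename_state_dict_key(..., DECODER_PATTERNS) to "model.decoder.layers.0.self_attn.q_proj.weight".

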
def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration:
    cfg = BigBirdPegasusConfig(**config_update)
    torch_model = BigBirdPegasusForConditionalGeneration(cfg)
    state_dict = torch_model.state_dict()
    mapping = {}

    # separating decoder weights
    decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")}
    remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")}

    for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"):
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
        if any(conditions):
            continue
        patterns = DECODER_PATTERNS
        new_k = rename_state_dict_key(k, patterns)
        if new_k not in state_dict:
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
            v = v.T
        mapping[new_k] = torch.from_numpy(v)
        assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"

    for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"):
        conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE]
        if any(conditions):
            continue
        patterns = REMAINING_PATTERNS
        new_k = rename_state_dict_key(k, patterns)
        if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings":
            raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})")
        if any(True if i in k else False for i in ["dense", "query", "key", "value"]):
            v = v.T
        mapping[new_k] = torch.from_numpy(v)
        if k != "pegasus/embeddings/position_embeddings":
            assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}"

    mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"]
    mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight")
    missing, extra = torch_model.load_state_dict(mapping, strict=False)
    unexpected_missing = [
        k
        for k in missing
        if k
        not in [
            "final_logits_bias",
            "model.encoder.embed_tokens.weight",
            "model.decoder.embed_tokens.weight",
            "lm_head.weight",
        ]
    ]
    assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}"
    assert extra == [], f"no matches found for the following tf keys {extra}"
    return torch_model


def get_tf_weights_as_numpy(path) -> Dict:
    init_vars = tf.train.list_variables(path)
    tf_weights = {}
    ignore_name = ["global_step"]
    for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"):
        skip_key = any(pat in name for pat in ignore_name)
        if skip_key:
            continue
        array = tf.train.load_variable(path, name)
        tf_weights[name] = array
    return tf_weights


def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict):
    tf_weights = get_tf_weights_as_numpy(ckpt_path)
    torch_model = convert_bigbird_pegasus(tf_weights, config_update)
    torch_model.save_pretrained(save_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables")
    parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.")
    args = parser.parse_args()
    config_update = {}
    convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update)
@@ -1,292 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import json
import os
import re
import shutil

import torch

from transformers import BioGptConfig, BioGptForCausalLM
from transformers.models.biogpt.tokenization_biogpt import VOCAB_FILES_NAMES
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
from transformers.utils import WEIGHTS_NAME, logging


logging.set_verbosity_warning()

json_indent = 2


# modified from https://github.com/facebookresearch/fairseq/blob/dd74992d0d143155998e9ed4076826bcea80fb06/fairseq/data/dictionary.py#L18
class Dictionary:
    """A mapping from symbols to consecutive integers"""

    def __init__(
        self,
        *,  # begin keyword-only arguments
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ):
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def _load_meta(self, lines):
        return 0

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols to this instance.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(word)
                    )
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt> [flags]'")


def rewrite_dict_keys(d):
    # (1) remove word breaking symbol, (2) add word ending symbol where the word is not broken up,
    # e.g.: d = {'le@@': 5, 'tt@@': 6, 'er': 7} => {'le': 5, 'tt': 6, 'er</w>': 7}
    d2 = dict((re.sub(r"@@$", "", k), v) if k.endswith("@@") else (re.sub(r"$", "</w>", k), v) for k, v in d.items())
    keep_keys = "<s> <pad> </s> <unk>".split()
    # restore the special tokens
    for k in keep_keys:
        del d2[f"{k}</w>"]
        d2[k] = d[k]  # restore
    return d2
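

# Illustrative example, added for clarity and not part of the original script: for a toy
# fairseq dict {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "le@@": 5, "er": 7},
# rewrite_dict_keys returns {"le": 5, "er</w>": 7, "<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}.

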
def convert_biogpt_checkpoint_to_pytorch(biogpt_checkpoint_path, pytorch_dump_folder_path):
    # prep
    if not os.path.exists(biogpt_checkpoint_path):
        raise ValueError(f"path {biogpt_checkpoint_path} does not exist!")
    os.makedirs(pytorch_dump_folder_path, exist_ok=True)
    print(f"Writing results to {pytorch_dump_folder_path}")

    # handle various types of models

    checkpoint_file = os.path.join(biogpt_checkpoint_path, "checkpoint.pt")
    if not os.path.isfile(checkpoint_file):
        raise ValueError(f"path to the file {checkpoint_file} does not exist!")
    chkpt = torch.load(checkpoint_file, map_location="cpu")

    args = chkpt["cfg"]["model"]

    # dicts
    dict_file = os.path.join(biogpt_checkpoint_path, "dict.txt")
    if not os.path.isfile(dict_file):
        raise ValueError(f"path to the file {dict_file} does not exist!")
    src_dict = Dictionary.load(dict_file)
    src_vocab = rewrite_dict_keys(src_dict.indices)
    src_vocab_size = len(src_vocab)
    src_vocab_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["vocab_file"])
    print(f"Generating {src_vocab_file} of {src_vocab_size} records")
    with open(src_vocab_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))

    # merges_file (bpecodes)
    bpecodes_file = os.path.join(biogpt_checkpoint_path, "bpecodes")
    if not os.path.isfile(bpecodes_file):
        raise ValueError(f"path to the file {bpecodes_file} does not exist!")

    merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
    shutil.copyfile(bpecodes_file, merges_file)

    # model config
    biogpt_model_config_file = os.path.join(pytorch_dump_folder_path, "config.json")

    model_conf = {
        "activation_dropout": args["activation_dropout"],
        "architectures": ["BioGptForCausalLM"],
        "attention_probs_dropout_prob": args["attention_dropout"],
        "bos_token_id": 0,
        "eos_token_id": 2,
        "hidden_act": args["activation_fn"],
        "hidden_dropout_prob": args["dropout"],
        "hidden_size": args["decoder_embed_dim"],
        "initializer_range": 0.02,
        "intermediate_size": args["decoder_ffn_embed_dim"],
        "layer_norm_eps": 1e-12,
        "layerdrop": args["decoder_layerdrop"],
        "max_position_embeddings": args["max_target_positions"],
        "model_type": "biogpt",
        "num_attention_heads": args["decoder_attention_heads"],
        "num_hidden_layers": args["decoder_layers"],
        "pad_token_id": 1,
        "scale_embedding": not args["no_scale_embedding"],
        "tie_word_embeddings": args["share_decoder_input_output_embed"],
        "vocab_size": src_vocab_size,
    }

    # good hparam defaults to start with

    print(f"Generating {biogpt_model_config_file}")
    with open(biogpt_model_config_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(model_conf, ensure_ascii=False, indent=json_indent))

    # tokenizer config
    biogpt_tokenizer_config_file = os.path.join(pytorch_dump_folder_path, TOKENIZER_CONFIG_FILE)

    tokenizer_conf = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "model_max_length": 1024,
        "pad_token": "<pad>",
        "special_tokens_map_file": None,
        "tokenizer_class": "BioGptTokenizer",
        "unk_token": "<unk>",
    }

    print(f"Generating {biogpt_tokenizer_config_file}")
    with open(biogpt_tokenizer_config_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(tokenizer_conf, ensure_ascii=False, indent=json_indent))

    # model
    model_state_dict = chkpt["model"]

    # remove unneeded keys
    ignore_keys = [
        "decoder.version",
    ]
    for k in ignore_keys:
        model_state_dict.pop(k, None)

    layer_names = list(model_state_dict.keys())
    for layer_name in layer_names:
        if layer_name.endswith("output_projection.weight"):
            model_state_dict[layer_name.replace("decoder.", "")] = model_state_dict.pop(layer_name)
        else:
            model_state_dict[layer_name.replace("decoder", "biogpt")] = model_state_dict.pop(layer_name)

    config = BioGptConfig.from_pretrained(pytorch_dump_folder_path)
    model_new = BioGptForCausalLM(config)

    # check that it loads ok
    model_new.load_state_dict(model_state_dict)

    # save
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    print(f"Generating {pytorch_weights_dump_path}")
    torch.save(model_state_dict, pytorch_weights_dump_path)

    print("Conversion is done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--biogpt_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help=(
            "Path to the official PyTorch checkpoint file which is expected to reside in the dump dir with dicts,"
            " bpecodes, etc."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_biogpt_checkpoint_to_pytorch(args.biogpt_checkpoint_path, args.pytorch_dump_folder_path)
@@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BiT checkpoints from the timm library."""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from transformers import BitConfig, BitForImageClassification, BitImageProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_config(model_name):
    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    conv_layer = "std_conv" if "bit" in model_name else False

    # note that when using BiT as backbone for ViT-hybrid checkpoints,
    # one needs to additionally set config.layer_type = "bottleneck", config.stem_type = "same",
    # config.conv_layer = "std_conv_same"
    config = BitConfig(
        conv_layer=conv_layer,
        num_labels=1000,
        id2label=id2label,
        label2id=label2id,
    )

    return config


def rename_key(name):
    if "stem.conv" in name:
        name = name.replace("stem.conv", "bit.embedder.convolution")
    if "blocks" in name:
        name = name.replace("blocks", "layers")
    if "head.fc" in name:
        name = name.replace("head.fc", "classifier.1")
    if name.startswith("norm"):
        name = "bit." + name
    if "bit" not in name and "classifier" not in name:
        name = "bit.encoder." + name

    return name
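

# Illustrative example, added for clarity and not part of the original script: a timm key such as
# "stem.conv.weight" becomes "bit.embedder.convolution.weight", and a hypothetical
# "stages.0.blocks.0.conv1.weight" becomes "bit.encoder.stages.0.layers.0.conv1.weight".

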
# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our BiT structure.
    """

    # define default BiT configuration
    config = get_config(model_name)

    # load original model from timm
    timm_model = create_model(model_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model
    state_dict = timm_model.state_dict()
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val.squeeze() if "head" in key else val

    # load HuggingFace model
    model = BitForImageClassification(config)
    model.eval()
    model.load_state_dict(state_dict)

    # create image processor
    transform = create_transform(**resolve_data_config({}, model=timm_model))
    timm_transforms = transform.transforms

    pillow_resamplings = {
        "bilinear": PILImageResampling.BILINEAR,
        "bicubic": PILImageResampling.BICUBIC,
        "nearest": PILImageResampling.NEAREST,
    }

    processor = BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": timm_transforms[0].size},
        resample=pillow_resamplings[timm_transforms[0].interpolation.value],
        do_center_crop=True,
        crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
        do_normalize=True,
        image_mean=timm_transforms[-1].mean.tolist(),
        image_std=timm_transforms[-1].std.tolist(),
    )

    image = prepare_img()
    timm_pixel_values = transform(image).unsqueeze(0)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # verify pixel values
    assert torch.allclose(timm_pixel_values, pixel_values)

    # verify logits
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits

    print("Logits:", logits[0, :3])
    print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
    timm_logits = timm_model(pixel_values)
    assert timm_logits.shape == outputs.logits.shape
    assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model {model_name} and processor to the hub")
        model.push_to_hub(f"ybelkada/{model_name}")
        processor.push_to_hub(f"ybelkada/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="resnetv2_50x1_bitm",
        type=str,
        help="Name of the BiT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub.",
    )

    args = parser.parse_args()
    convert_bit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
@@ -1,114 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Blenderbot checkpoint."""

import argparse

import torch

from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

PATTERNS = [
    ["attention", "attn"],
    ["encoder_attention", "encoder_attn"],
    ["q_lin", "q_proj"],
    ["k_lin", "k_proj"],
    ["v_lin", "v_proj"],
    ["out_lin", "out_proj"],
    ["norm_embeddings", "layernorm_embedding"],
    ["position_embeddings", "embed_positions"],
    ["embeddings", "embed_tokens"],
    ["ffn.lin", "fc"],
]


def rename_state_dict_key(k):
    if k == "embeddings.weight":
        return "shared.weight"

    for parlai_name, hf_name in PATTERNS:
        k = k.replace(parlai_name, hf_name)

    if k.startswith("encoder"):
        k = k.replace(".attn", ".self_attn")
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "final_layer_norm")
    elif k.startswith("decoder"):
        k = k.replace("norm1", "self_attn_layer_norm")
        k = k.replace("norm2", "encoder_attn_layer_norm")
        k = k.replace("norm3", "final_layer_norm")
    return k
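

# Illustrative example, added for clarity and not part of the original script: a hypothetical
# ParlAI key "encoder.layers.0.attention.q_lin.weight" is rewritten to
# "encoder.layers.0.self_attn.q_proj.weight".

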
def rename_layernorm_keys(sd):
    keys = [
        "model.encoder.layernorm_embedding.weight",
        "model.encoder.layernorm_embedding.bias",
        "model.decoder.layernorm_embedding.weight",
        "model.decoder.layernorm_embedding.bias",
    ]
    for k in keys:
        v = sd.pop(k)
        new_k = k.replace("layernorm_embedding", "layer_norm")
        assert new_k not in sd
        sd[new_k] = v


IGNORE_KEYS = ["START"]


@torch.no_grad()
def convert_parlai_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_json_path):
    """
    Copy/paste/tweak model's weights to our Blenderbot structure.
    """
    model = torch.load(checkpoint_path, map_location="cpu")
    sd = model["model"]
    cfg = BlenderbotConfig.from_json_file(config_json_path)
    m = BlenderbotForConditionalGeneration(cfg)
    valid_keys = m.model.state_dict().keys()
    failures = []
    mapping = {}
    for k, v in sd.items():
        if k in IGNORE_KEYS:
            continue

        new_k = rename_state_dict_key(k)
        if new_k not in valid_keys:
            failures.append([k, new_k])
        else:
            mapping[new_k] = v
    if cfg.normalize_before:  # Blenderbot-3B checkpoints. Rename layernorm_embedding -> layer_norm
        rename_layernorm_keys(sd)
    m.model.load_state_dict(mapping, strict=True)
    m.half()
    m.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--src_path", type=str, help="like blenderbot-model.bin")
    parser.add_argument("--save_dir", default="hf_blenderbot", type=str, help="Where to save converted model.")
    parser.add_argument(
        "--hf_config_json", default="blenderbot-3b-config.json", type=str, help="Path to config to use"
    )
    args = parser.parse_args()
    convert_parlai_checkpoint(args.src_path, args.save_dir, args.hf_config_json)
@@ -1,191 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re

import requests
import torch

# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    BertTokenizer,
    BlipConfig,
    BlipForConditionalGeneration,
    BlipForImageTextRetrieval,
    BlipForQuestionAnswering,
)


def load_demo_image(image_size, device):
    img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    image = transform(raw_image).unsqueeze(0).to(device)
    return image


def rename_key(key):
    if "visual_encoder" in key:
        key = re.sub("visual_encoder*", "vision_model.encoder", key)
    if "blocks" in key:
        key = re.sub(r"blocks", "layers", key)
    if "attn" in key:
        key = re.sub(r"attn", "self_attn", key)
    if "norm1" in key:
        key = re.sub(r"norm1", "layer_norm1", key)
    if "norm2" in key:
        key = re.sub(r"norm2", "layer_norm2", key)
    if "encoder.norm" in key:
        key = re.sub(r"encoder.norm", "post_layernorm", key)
    if "encoder.patch_embed.proj" in key:
        key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)

    if "encoder.pos_embed" in key:
        key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
    if "encoder.cls_token" in key:
        key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)

    if "self_attn" in key:
        key = re.sub(r"self_attn.proj", "self_attn.projection", key)

    return key
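

# Illustrative example, added for clarity and not part of the original script: the BLIP key
# "visual_encoder.blocks.0.attn.qkv.weight" is rewritten to
# "vision_model.encoder.layers.0.self_attn.qkv.weight".

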
@torch.no_grad()
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = BlipConfig.from_pretrained(config_path)
    else:
        config = BlipConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = BlipForConditionalGeneration(config).eval()

    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"

    pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
    pt_model = pt_model.eval()

    modified_state_dict = pt_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_model.load_state_dict(modified_state_dict)

    image_size = 384
    image = load_demo_image(image_size=image_size, device="cpu")
    tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
    input_ids = tokenizer(["a picture of"]).input_ids

    out = hf_model.generate(image, input_ids)

    assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    out = hf_model.generate(image)

    assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    if pytorch_dump_folder_path is not None:
        hf_model.save_pretrained(pytorch_dump_folder_path)

    # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
    model_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
    )

    vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
    vqa_model.eval()

    modified_state_dict = vqa_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_vqa_model = BlipForQuestionAnswering(config)

    hf_vqa_model.load_state_dict(modified_state_dict)

    question = ["How many dogs are in this image?"]
    question_input_ids = tokenizer(question, return_tensors="pt").input_ids

    answer = hf_vqa_model.generate(question_input_ids, image)
    print(tokenizer.decode(answer[0]))

    assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"
    if pytorch_dump_folder_path is not None:
        hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")

    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth"

    itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base")
    itm_model.eval()

    modified_state_dict = itm_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_itm_model = BlipForImageTextRetrieval(config)

    question = ["A picture of a woman with a dog sitting in a beach"]
    question_input_ids = tokenizer(
        question,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=35,
    ).input_ids

    hf_itm_model.load_state_dict(modified_state_dict)
    hf_itm_model.eval()

    out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True)
    out = hf_itm_model(question_input_ids, image, use_itm_head=False)

    assert out[0].item() == 0.2110687494277954
    assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127

    if pytorch_dump_folder_path is not None:
        hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
@@ -1,390 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert BLIP-2 checkpoints from the original repository.

URL: https://github.com/salesforce/LAVIS/tree/main/projects/blip2
"""

import argparse

import requests
import torch

# pip3 install salesforce-lavis
# I'm actually installing a slightly modified version: pip3 install -U git+https://github.com/nielsrogge/LAVIS.git@blip2_float32
# to make sure we can compare both original and HF implementation in float32
from lavis.models import load_model_and_preprocess
from PIL import Image

from transformers import (
    AutoTokenizer,
    BertTokenizer,
    Blip2Config,
    Blip2ForConditionalGeneration,
    Blip2ForImageTextRetrieval,
    Blip2Processor,
    Blip2QFormerConfig,
    Blip2VisionConfig,
    BlipImageProcessor,
    OPTConfig,
    T5Config,
    set_seed,
)
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD


def load_demo_image():
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    return image


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, model_name):
    rename_keys = []
    # fmt: off

    # vision encoder
    rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
    rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
    rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
    rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
    rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
    rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))

    for i in range(config.vision_config.num_hidden_layers):
        rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
        rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
        rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))

    # QFormer
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.layernorm.weight"))
    rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.layernorm.bias"))
    if "itm" in model_name:
        rename_keys.append(("Qformer.bert.embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"))
        rename_keys.append(("Qformer.bert.embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"))
        rename_keys.append(("vision_proj.weight", "vision_projection.weight"))
        rename_keys.append(("vision_proj.bias", "vision_projection.bias"))
        rename_keys.append(("text_proj.weight", "text_projection.weight"))
        rename_keys.append(("text_proj.bias", "text_projection.bias"))

    # fmt: on
    return rename_keys


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def read_in_q_v_bias(state_dict, config):
    for i in range(config.vision_config.num_hidden_layers):
        # read in original q and v biases
        q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
        v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")

        # next, set bias in the state dict
        qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
        state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
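

# Note on read_in_q_v_bias above, added for clarity and not part of the original script: the original
# checkpoints carry no bias for the key projection, so the fused qkv bias is assembled as
# cat([q_bias, zeros_like(v_bias), v_bias]) before being written into the HF state dict.

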
def get_blip2_config(model_name, eos_token_id):
 | 
			
		||||
    image_size = 364 if "coco" in model_name else 224
 | 
			
		||||
    vision_config = Blip2VisionConfig(image_size=image_size).to_dict()
 | 
			
		||||
 | 
			
		||||
    # make sure the models have proper bos_token_id and eos_token_id set (important for generation)
 | 
			
		||||
    # seems like flan-T5 models don't have bos_token_id properly set?
 | 
			
		||||
    if "opt-2.7b" in model_name:
 | 
			
		||||
        text_config = OPTConfig.from_pretrained("facebook/opt-2.7b", eos_token_id=eos_token_id).to_dict()
 | 
			
		||||
    elif "opt-6.7b" in model_name:
 | 
			
		||||
        text_config = OPTConfig.from_pretrained("facebook/opt-6.7b", eos_token_id=eos_token_id).to_dict()
 | 
			
		||||
    elif "t5-xl" in model_name:
 | 
			
		||||
        text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
 | 
			
		||||
    elif "t5-xxl" in model_name:
 | 
			
		||||
        text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
 | 
			
		||||
    elif "itm" in model_name:
 | 
			
		||||
        text_config = {}
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Model name not supported")
 | 
			
		||||
 | 
			
		||||
    if "itm" in model_name:
 | 
			
		||||
        config = Blip2Config(
 | 
			
		||||
            vision_config=vision_config,
 | 
			
		||||
            qformer_config=Blip2QFormerConfig(vocab_size=30523, use_qformer_text_input=True).to_dict(),
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        config = Blip2Config(vision_config=vision_config, text_config=text_config)
 | 
			
		||||
 | 
			
		||||
    return config, image_size


@torch.no_grad()
def convert_blip2_checkpoint(
    model_name, pytorch_dump_folder_path=None, push_to_hub=False, lavis_device="cpu", hf_model_device="cpu"
):
    """
    Copy/paste/tweak model's weights to Transformers design.
    """
    if "opt" in model_name:
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
    elif "itm" in model_name:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
    else:
        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")

    if "itm" in model_name:
        eos_token_id = None
    else:
        eos_token_id = tokenizer("\n", add_special_tokens=False).input_ids[0]
    config, image_size = get_blip2_config(model_name, eos_token_id=eos_token_id)

    if "itm" in model_name:
        hf_model = Blip2ForImageTextRetrieval(config).eval()
    else:
        hf_model = Blip2ForConditionalGeneration(config).eval()

    model_name_to_original = {
        "blip2-opt-2.7b": ("blip2_opt", "pretrain_opt2.7b"),
        "blip2-opt-6.7b": ("blip2_opt", "pretrain_opt6.7b"),
        "blip2-opt-2.7b-coco": ("blip2_opt", "caption_coco_opt2.7b"),
        "blip2-opt-6.7b-coco": ("blip2_opt", "caption_coco_opt6.7b"),
        "blip2-flan-t5-xl": ("blip2_t5", "pretrain_flant5xl"),
        "blip2-flan-t5-xl-coco": ("blip2_t5", "caption_coco_flant5xl"),
        "blip2-flan-t5-xxl": ("blip2_t5", "pretrain_flant5xxl"),
        "blip2-itm-vit-g": ("blip2_image_text_matching", "pretrain"),
        "blip2-itm-vit-g-coco": ("blip2_image_text_matching", "coco"),
    }

    name, type = model_name_to_original[model_name]

    # load original model
    print("Loading original model...")
    original_model, vis_processors, _ = load_model_and_preprocess(
        name=name, model_type=type, is_eval=True, device=lavis_device
    )
    original_model.eval()
    print("Done!")

    # update state dict keys
    state_dict = original_model.state_dict()
    rename_keys = create_rename_keys(config, model_name)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)

    # some keys can be renamed efficiently
    for key, val in state_dict.copy().items():
        val = state_dict.pop(key)
        if key.startswith("Qformer.bert"):
            key = key.replace("Qformer.bert", "qformer")
        if "attention.self" in key:
            key = key.replace("self", "attention")
        if "opt_proj" in key:
            key = key.replace("opt_proj", "language_projection")
        if "t5_proj" in key:
            key = key.replace("t5_proj", "language_projection")
        if key.startswith("opt"):
            key = key.replace("opt", "language")
        if key.startswith("t5"):
            key = key.replace("t5", "language")
        state_dict[key] = val

    # read in qv biases
    read_in_q_v_bias(state_dict, config)

    missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
    assert len(missing_keys) == 0

    if "itm" in model_name:
        unexpected_keys = list(filter(lambda x: not x.startswith("Qformer.cls"), unexpected_keys))
        assert unexpected_keys == ["temp", "qformer.embeddings.position_ids"]
    else:
        assert unexpected_keys == ["qformer.embeddings.position_ids"]

    image = load_demo_image()
    original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)

    # create processor
    image_processor = BlipImageProcessor(
        size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
    )
    processor = Blip2Processor(image_processor=image_processor, tokenizer=tokenizer)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(hf_model_device)

    # make sure processor creates exact same pixel values
    assert torch.allclose(pixel_values, original_pixel_values.to(pixel_values.device))

    original_model.to(lavis_device)
    hf_model.to(hf_model_device)

    if "itm" in model_name:
        caption = "a large fountain spewing water into the air"
        input_ids = tokenizer([caption], return_tensors="pt").input_ids.to(hf_model_device)
        attention_mask = processor(text=caption, return_tensors="pt").attention_mask.to(hf_model_device)

        with torch.no_grad():
            original_logits = original_model(
                {"image": original_pixel_values, "text_input": [caption]}, match_head="itm"
            )
            logits = hf_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=True,
            )

        assert original_logits.shape == logits.logits_per_image.shape
        print("First values of original logits:", original_logits[0, :3])
        print("First values of HF logits:", logits.logits_per_image[0, :3])

        # assert values
        # cast to same type
        target_dtype = logits.logits_per_image.dtype
        assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)

        original_itm_scores = torch.nn.functional.softmax(original_logits, dim=1)
        itm_scores = torch.nn.functional.softmax(logits.logits_per_image, dim=1)
        assert torch.allclose(original_itm_scores.to(target_dtype), itm_scores, atol=1e-4)
        print("Looks ok!")

        with torch.no_grad():
            original_logits = original_model(
                {"image": original_pixel_values, "text_input": [caption]}, match_head="itc"
            )
            logits = hf_model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_image_text_matching_head=False,
            )

        assert original_logits.shape == logits.logits_per_image.shape
        print("First values of original logits:", original_logits[0, :3])
        print("First values of HF logits:", logits.logits_per_image[0, :3])

        # assert values
        # cast to same type
        target_dtype = logits.logits_per_image.dtype
        assert torch.allclose(original_logits.to(target_dtype), logits.logits_per_image, atol=1e-4)
        print("Looks ok!")

    else:
        input_ids = tokenizer(["\n"], return_tensors="pt").input_ids.to(hf_model_device)

        with torch.no_grad():
            if "opt" in model_name:
                original_logits = original_model({"image": original_pixel_values, "text_input": [""]}).logits
                logits = hf_model(pixel_values, input_ids).logits
            else:
                original_logits = original_model(
                    {"image": original_pixel_values, "text_input": ["\n"], "text_output": ["\n"]}
                ).logits
                labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
                logits = hf_model(pixel_values, input_ids, labels=labels).logits

        assert original_logits.shape == logits.shape
        print("First values of original logits:", original_logits[0, :3, :3])
        print("First values of HF logits:", logits[0, :3, :3])

        # assert values
        assert torch.allclose(original_logits.to(logits.device), logits, atol=1e-4)
        print("Looks ok!")

        print("Generating a caption...")
        prompt = "Question: what object is in this image? Answer:"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(hf_model_device)

        set_seed(42)

        original_outputs = original_model.generate(
            {"image": original_pixel_values, "prompt": prompt}, use_nucleus_sampling=True, max_length=50
        )
        outputs = hf_model.generate(
            pixel_values,
            input_ids,
            do_sample=True,
            num_beams=5,
            max_length=30,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=1.0,
            temperature=1,
        )
        output_text = processor.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        print("Original generation:", original_outputs)
        print("HF generation:", output_text)

    if pytorch_dump_folder_path is not None:
        processor.save_pretrained(pytorch_dump_folder_path)
        hf_model.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        processor.push_to_hub(f"nielsr/{model_name}")
        hf_model.push_to_hub(f"nielsr/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    choices = [
        "blip2-opt-2.7b",
        "blip2-opt-6.7b",
        "blip2-opt-2.7b-coco",
        "blip2-opt-6.7b-coco",
        "blip2-flan-t5-xl",
        "blip2-flan-t5-xl-coco",
        "blip2-flan-t5-xxl",
        "blip2-itm-vit-g",
        "blip2-itm-vit-g-coco",
    ]
    parser.add_argument(
        "--model_name",
        default="blip2-opt-2.7b",
        choices=choices,
        type=str,
        help="Path to hf config.json of model to convert",
    )
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model and processor to the hub after converting",
    )
    # note: this script is tested on 2 GPUs, as models are compared in float32,
    # which requires quite some memory. Hence loading both on a
    # separate device is the easiest to compare
    parser.add_argument(
        "--lavis_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )
    parser.add_argument(
        "--hf_model_device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )

    args = parser.parse_args()

    convert_blip2_checkpoint(
        args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.lavis_device, args.hf_model_device
    )
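
# Illustrative follow-up (a hedged sketch, not part of the original script; the local folder
# name is hypothetical): once the conversion has been run with --pytorch_dump_folder_path,
# the converted checkpoint can be reloaded with the regular Transformers API, e.g.
#
#     from transformers import Blip2ForConditionalGeneration, Blip2Processor
#
#     processor = Blip2Processor.from_pretrained("./blip2-opt-2.7b")
#     model = Blip2ForConditionalGeneration.from_pretrained("./blip2-opt-2.7b")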
@ -1,254 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""

import argparse
import json
import os
import re

import torch

from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()

WEIGHTS_TO_AVERAGE_ENDSWITH = [
    "word_embeddings_layernorm.weight",
    "word_embeddings_layernorm.bias",
    "input_layernorm.weight",
    "input_layernorm.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "self_attention.dense.bias",
    "mlp.dense_4h_to_h.bias",
    "ln_f.weight",
    "ln_f.bias",
]

WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
    "mlp.dense_4h_to_h.weight",
    "self_attention.dense.weight",
]


def layer_name_mapping(key, file):
    """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only"""
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
    layer_number -= 3
    return f"h.{layer_number}." + key


def get_dtype_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
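
# Worked example (illustrative only, not used by the script): get_dtype_size returns the
# per-element size in bytes that feeds the "total_size" metadata accumulated below, e.g.
#
#     assert get_dtype_size(torch.float16) == 2
#     assert get_dtype_size(torch.bfloat16) == 2
#     assert get_dtype_size(torch.float32) == 4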


def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    if shard_model:
        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0

        missing_keys = None

        config = BloomConfig()

        for j, file in enumerate(file_names):
            print("Processing file: {}".format(file))
            tensors = None

            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide the weights we want to average by the number of TP ranks
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp
            torch.save(
                tensors,
                os.path.join(
                    pytorch_dump_folder_path,
                    "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
                ),
            )

            for key in tensors.keys():
                value = tensors[key]
                total_size += value.numel() * get_dtype_size(value.dtype)
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
                        str(j + 1).zfill(5), str(len(file_names)).zfill(5)
                    )

        config = BloomConfig()
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        index_dict["metadata"]["total_size"] = total_size
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
            f.write(json_config)
    else:
        model = BloomModel(config)

        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        missing_keys = None
        for i, file in enumerate(file_names):
            tensors = None
            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide the weights we want to average by the number of TP ranks
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp

            other_keys = model.load_state_dict(tensors, strict=False)
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
            if missing_keys is None:
                missing_keys = set(other_keys.missing_keys)
            else:
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))

        assert not missing_keys, f"The keys {missing_keys} are missing"

        # Save pytorch-model
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
        if config.torch_dtype is not None:
            model = model.to(config.torch_dtype)
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {pytorch_config_dump_path}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--bloom_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the Megatron-LM checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--bloom_config_file",
        default="",
        type=str,
        help=(
            "An optional config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--shard_model",
        action="store_true",
        help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
    )
    parser.add_argument(
        "--pretraining_tp",
        default=4,
        type=int,
        help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
    )
    args = parser.parse_args()
    convert_bloom_checkpoint_to_pytorch(
        args.bloom_checkpoint_path,
        args.bloom_config_file,
        args.pytorch_dump_folder_path,
        args.shard_model,
        args.pretraining_tp,
    )
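
# Example invocation (a hedged sketch; the script filename and paths are placeholders, and
# --pretraining_tp must match the tensor-parallel degree used during pretraining):
#
#     python convert_bloom_original_checkpoint_to_pytorch.py \
#         --bloom_checkpoint_path /path/to/megatron_deepspeed_checkpoint \
#         --pytorch_dump_folder_path /path/to/output \
#         --pretraining_tp 4 \
#         --shard_model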
@ -1,145 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bros checkpoints."""

import argparse

import bros  # original repo
import torch

from transformers import BrosConfig, BrosModel, BrosProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_configs(model_name):
    bros_config = BrosConfig.from_pretrained(model_name)
    return bros_config


def remove_ignore_keys_(state_dict):
    ignore_keys = [
        "embeddings.bbox_sinusoid_emb.inv_freq",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(name):
    if name == "embeddings.bbox_projection.weight":
        name = "bbox_embeddings.bbox_projection.weight"

    if name == "embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.x_pos_emb.inv_freq"

    if name == "embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq":
        name = "bbox_embeddings.bbox_sinusoid_emb.y_pos_emb.inv_freq"

    return name


def convert_state_dict(orig_state_dict, model):
    # rename keys
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)
        orig_state_dict[rename_key(key)] = val

    # remove ignore keys
    remove_ignore_keys_(orig_state_dict)

    return orig_state_dict


def convert_bros_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    # load original model
    original_model = bros.BrosModel.from_pretrained(model_name).eval()

    # load HuggingFace Model
    bros_config = get_configs(model_name)
    model = BrosModel.from_pretrained(model_name, config=bros_config)
    model.eval()

    state_dict = original_model.state_dict()
    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # verify results

    # the original BROS model requires 4 points (8 float values) for each bbox, prepare bbox with [batch_size, seq_len, 8] shape
    bbox = torch.tensor(
        [
            [
                [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.4396, 0.6720, 0.4659, 0.6720, 0.4659, 0.6850, 0.4396, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.4698, 0.6720, 0.4843, 0.6720, 0.4843, 0.6850, 0.4698, 0.6850],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [0.2047, 0.6870, 0.2730, 0.6870, 0.2730, 0.7000, 0.2047, 0.7000],
                [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
            ]
        ]
    )

    processor = BrosProcessor.from_pretrained(model_name)

    encoding = processor("His name is Rocco.", return_tensors="pt")
    encoding["bbox"] = bbox

    original_hidden_states = original_model(**encoding).last_hidden_state
    # pixel_values = processor(image, return_tensors="pt").pixel_values

    last_hidden_states = model(**encoding).last_hidden_state

    assert torch.allclose(original_hidden_states, last_hidden_states, atol=1e-4)

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")
        processor.push_to_hub("jinho8345/" + model_name.split("/")[-1], commit_message="Update model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_name",
        default="jinho8345/bros-base-uncased",
        required=False,
        type=str,
        help="Name of the original model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        required=False,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model and processor to the 🤗 hub.",
    )

    args = parser.parse_args()
    convert_bros_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
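
# Example invocation (a hedged sketch; the script filename and output path are placeholders):
#
#     python convert_bros_to_pytorch.py \
#         --model_name jinho8345/bros-base-uncased \
#         --pytorch_dump_folder_path ./bros-base-uncased-converted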
@ -1,59 +0,0 @@
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""

import argparse

from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = T5ForConditionalGeneration(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained T5 model. \nThis specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
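
# Illustrative follow-up (a hedged sketch; the output path is a placeholder): the dumped folder
# can then be loaded back with the regular Transformers API, e.g.
#
#     from transformers import T5ForConditionalGeneration
#
#     model = T5ForConditionalGeneration.from_pretrained("/path/to/output")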
@ -1,65 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CANINE checkpoint."""

import argparse

from transformers import CanineConfig, CanineModel, CanineTokenizer, load_tf_weights_in_canine
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_path):
    # Initialize PyTorch model
    config = CanineConfig()
    model = CanineModel(config)
    model.eval()

    print(f"Building PyTorch model from configuration: {config}")

    # Load weights from tf checkpoint
    load_tf_weights_in_canine(model, config, tf_checkpoint_path)

    # Save pytorch-model (weights and configuration)
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)

    # Save tokenizer files
    tokenizer = CanineTokenizer()
    print(f"Save tokenizer files to {pytorch_dump_path}")
    tokenizer.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the TensorFlow checkpoint. Should end with model.ckpt",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default=None,
        type=str,
        required=True,
        help="Path to a folder where the PyTorch model will be placed.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.pytorch_dump_path)
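
# Illustrative follow-up (a hedged sketch; the output path is a placeholder): both the model
# weights and the tokenizer files are saved to the same folder, so they can be reloaded with
#
#     from transformers import CanineModel, CanineTokenizer
#
#     model = CanineModel.from_pretrained("/path/to/output")
#     tokenizer = CanineTokenizer.from_pretrained("/path/to/output")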
@ -1,476 +0,0 @@
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os

import requests
import torch
import yaml
from accelerate import init_empty_weights
from PIL import Image

from transformers import (
    ChameleonConfig,
    ChameleonForConditionalGeneration,
    ChameleonImageProcessor,
    ChameleonProcessor,
)


try:
    from transformers import LlamaTokenizerFast
except ImportError:
    raise ValueError(
        "Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! "
        "Update your `tokenizers` library and re-run the tokenizer conversion."
    )

"""
Sample usage:

```
python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \
    --input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path
```

Thereafter, models can be loaded via:

```py
from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast

model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints, each checkpoint contains a part of each weight of the model, so we need to load them all in RAM).
"""

NUM_SHARDS = {
    "7B": 1,
    "30B": 4,
}

VOCAB_SIZE = 65536


def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
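
# Worked example (illustrative only): with the default ffn_dim_multiplier=1 and multiple_of=256,
# a hidden size of n = 4096 gives int(8 * 4096 / 3) = 10922, which is rounded up to the next
# multiple of 256, so compute_intermediate_size(4096) == 11008.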


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, model_size, chameleon_version=1):
    os.makedirs(model_path, exist_ok=True)
    input_model_path = os.path.join(input_base_path, "models", model_size.lower())
    params_path = os.path.join(input_model_path, "params.json")
    consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json")

    params = read_json(params_path)
    if os.path.isfile(consolidate_params_path):
        params = {**params, **read_json(consolidate_params_path)}
    num_shards = NUM_SHARDS[model_size]
    model_parallel_size = params["model_parallel_size"]
    params = params.get("model", params)
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = params.get("rope_theta", 10000.0)
    swin_norm = params["swin_norm"]
    if base > 10000.0:
        max_position_embeddings = 16384
    else:
        # Depending on the Chameleon version, the default max_position_embeddings has different values.
        if chameleon_version == 1:
            max_position_embeddings = 4096
        else:
            raise NotImplementedError(
                f"Version {chameleon_version} of chameleon is not supported yet. "
                "Current supported versions of chameleon are [1]."
            )

    if params.get("n_kv_heads", None) is not None:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
        key_value_dim = dim // num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    print(f"Fetching all parameters from the checkpoint at {input_model_path}.")
    # Load weights
    if num_shards == 1:
        # Not sharded
        # (The sharded implementation would also work, but this is simpler.)
        loaded = None
        for possible_name in ["consolidated.pth", "consolidated.00.pth"]:
            possible_path = os.path.join(input_model_path, possible_name)
            if os.path.exists(possible_path):
                loaded = torch.load(possible_path, map_location="cpu")
                break
        assert loaded is not None
    else:
        # Sharded
        loaded = [
            torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]

    # permute for sliced rotary
    def permute(w, n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # Load weights to the state dict
    state_dict = {}
    for layer_i in range(n_layers):
        if num_shards == 1:
            # Unsharded
            state_dict.update(
                {
                    f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                        loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
                    ),
                    f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                        loaded[f"layers.{layer_i}.attention.wk.weight"],
                        n_heads=num_key_value_heads,
                        dim1=key_value_dim,
                    ),
                    f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                    f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
                    f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
                    f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
                    f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
                    f"model.layers.{layer_i}.input_layernorm.weight": loaded[
                        f"layers.{layer_i}.attention_norm.weight"
                    ],
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
                        f"layers.{layer_i}.ffn_norm.weight"
                    ],
                }
            )
            # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
                loaded[f"layers.{layer_i}.attention.q_normalization.weight"]
                .view(dims_per_head // 2, 2)
                .t()
                .reshape(1, -1)
                .repeat_interleave(n_heads, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
                loaded[f"layers.{layer_i}.attention.q_normalization.bias"]
                .view(dims_per_head // 2, 2)
                .t()
                .reshape(1, -1)
                .repeat_interleave(n_heads, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
                loaded[f"layers.{layer_i}.attention.k_normalization.weight"]
                .view(dims_per_head // 2, 2)
                .t()
                .reshape(1, -1)
                .repeat_interleave(num_key_value_heads, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
                loaded[f"layers.{layer_i}.attention.k_normalization.bias"]
                .view(dims_per_head // 2, 2)
                .t()
                .reshape(1, -1)
                .repeat_interleave(num_key_value_heads, 0)
            )

        else:
            # Sharded
            state_dict.update(
                {
                    f"model.layers.{layer_i}.input_layernorm.weight": torch.stack(
                        [l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded]
                    ).mean(dim=0),
                    f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack(
                        [l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded]
                    ).mean(dim=0),
                }
            )
            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(dim, dim),
                n_heads=n_heads,
            )

            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
                            num_local_key_value_heads, dims_per_head, dim
                        )
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(key_value_dim, dim),
                n_heads=num_key_value_heads,
                dim1=key_value_dim,
            )

            # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677)
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = (
                torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded])
                .view(num_shards, dims_per_head // 2, 2)
                .transpose(1, 2)
                .reshape(num_shards, -1)
                .repeat_interleave(n_heads // num_shards, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = (
                torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded])
                .view(num_shards, dims_per_head // 2, 2)
                .transpose(1, 2)
                .reshape(num_shards, -1)
                .repeat_interleave(n_heads // num_shards, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = (
                torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded])
                .view(num_shards, dims_per_head // 2, 2)
                .transpose(1, 2)
                .reshape(num_shards, -1)
                .repeat_interleave(num_key_value_heads // num_shards, 0)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = (
                torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded])
                .view(num_shards, dims_per_head // 2, 2)
                .transpose(1, 2)
                .reshape(num_shards, -1)
                .repeat_interleave(num_key_value_heads // num_shards, 0)
            )
 | 
			
		||||
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
 | 
			
		||||
                [
 | 
			
		||||
                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
 | 
			
		||||
                        num_local_key_value_heads, dims_per_head, dim
 | 
			
		||||
                    )
 | 
			
		||||
                    for i in range(num_shards)
 | 
			
		||||
                ],
 | 
			
		||||
                dim=0,
 | 
			
		||||
            ).reshape(key_value_dim, dim)
 | 
			
		||||
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
 | 
			
		||||
            )
 | 
			
		||||
            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
 | 
			
		||||
                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if num_shards == 1:
        # Unsharded
        state_dict.update(
            {
                "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
                "model.norm.weight": loaded["norm.weight"],
                "lm_head.weight": loaded["output.weight"],
            }
        )
    else:
        state_dict.update(
            {
                "model.embed_tokens.weight": torch.cat(
                    [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
                ),
                "model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0),
                "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
            }
        )

    # Load VQGAN weights
    vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt")
    vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"]
    for k, v in vqgan_state_dict.items():
        if "decoder" in k:
            continue  # we don't do image generation yet
        state_dict[f"model.vqmodel.{k}"] = v

    # Write configs
    ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
    multiple_of = params["multiple_of"] if "multiple_of" in params else 256

    with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file:
        tokenizer_config = json.load(tokenizer_file)
        vocabulary_map = tokenizer_config["model"]["vocab"]
        vocabulary_map["<image>"] = vocabulary_map[
            "<reserved08707>"
        ]  # use a reserved token instead of adding a new one
        del vocabulary_map["<reserved08707>"]

        for token in tokenizer_config["added_tokens"]:
            if token["content"] == "<reserved08707>":
                token["content"] = "<image>"

    with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f:
        json.dump(tokenizer_config, f)  # save the new file to init tokenizer later

    vq_keys_to_replace = [
        ("ch", "base_channels"),
        ("out_ch", "out_channels"),
        ("n_embed", "num_embeddings"),
        ("ch_mult", "channel_multiplier"),
        ("double_z", "double_latent"),
        ("z_channels", "latent_channels"),
    ]
    with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file:
        vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"]
        vq_config.update(**vq_config["ddconfig"])
        for old, new in vq_keys_to_replace:
            vq_config[new] = vq_config[old]
        del vq_config["ddconfig"]
        del vq_config["ckpt_path"]
        del vq_config["lossconfig"]

    config = ChameleonConfig(
        hidden_size=dim,
        intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=num_key_value_heads,
        vocab_size=VOCAB_SIZE,
        rope_theta=base,
        max_position_embeddings=max_position_embeddings,
        model_parallel_size=model_parallel_size,
        swin_norm=swin_norm,
        vq_config=vq_config,
        vocabulary_map=vocabulary_map,
    )
    with init_empty_weights():
        model = ChameleonForConditionalGeneration(config)

    model.load_state_dict(state_dict, assign=True, strict=False)
    model.save_pretrained(model_path, safe_serialization=True)

    # Load and save the processor
    tokenizer = LlamaTokenizerFast(
        tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False
    )
    tokenizer.sep_token_id = 8710  # assign <reserved08706> to sep so that we can append it after input text
    tokenizer.pad_token_id = 1  # assign <pad> to special pad_token
    image_processor = ChameleonImageProcessor()
    processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer)
    processor.save_pretrained(model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    del vqgan_state_dict
    gc.collect()

    # Short inference on a few examples to check if generation makes sense
    # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
    print("Loading the checkpoint in a Chameleon model...")
    print("*" * 100)
    model = ChameleonForConditionalGeneration.from_pretrained(
        model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = ChameleonProcessor.from_pretrained(model_path)

    prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    image = Image.open(
        requests.get(
            "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
        ).raw
    )
    inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16)
    length = inputs.input_ids.shape[1]

    out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
    generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]

    print(f"Generation for single-image: {generated_text}")
    print("*" * 100)

    # Multi-image example
    prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
    image = Image.open(
        requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
    )
    image_2 = Image.open(
        requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
    )

    inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
    length = inputs.input_ids.shape[1]
    out = model.generate(**inputs, max_new_tokens=50, do_sample=False)
    generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]

    print(f"Generation for multi-image: {generated_text}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of Chameleon weights",
    )
    parser.add_argument(
        "--model_size",
        choices=["7B", "30B"],
        help=""
        " models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, checkout the original repo: https://github.com/facebookresearch/chameleon",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model",
    )
    parser.add_argument(
        "--test_inference",
        action="store_true",
        help="Whether to load the model for generation to test it's converted correctly.",
    )
    # Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
    parser.add_argument(
        "--chameleon_version",
        choices=[1],
        default=1,
        type=int,
        help="Version of the Chameleon model to convert",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        input_base_path=args.input_dir,
        model_size=args.model_size,
        chameleon_version=args.chameleon_version,
    )


if __name__ == "__main__":
    main()
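# Illustrative usage (added; not part of the original script). The script filename and the
# paths below are placeholders; the input directory is expected to contain the
# tokenizer/vqgan.ckpt, tokenizer/vqgan.yaml and tokenizer/text_tokenizer.json files used above.
#   python convert_chameleon_weights_to_hf.py \
#       --input_dir /path/to/chameleon-release \
#       --model_size 7B \
#       --output_dir /path/to/hf-chameleon-7b \
#       --test_inference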
@ -1,134 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The OFA-Sys Team Authors and The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import ChineseCLIPConfig, ChineseCLIPModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_attn_layer(hf_attn_layer, pt_weights, prefix):
 | 
			
		||||
    q_proj, k_proj, v_proj = pt_weights[f"{prefix}.in_proj_weight"].chunk(3, dim=0)
 | 
			
		||||
    q_proj_bias, k_proj_bias, v_proj_bias = pt_weights[f"{prefix}.in_proj_bias"].chunk(3, dim=0)
 | 
			
		||||
 | 
			
		||||
    out_proj_weights = pt_weights[f"{prefix}.out_proj.weight"]
 | 
			
		||||
    out_proj_bias = pt_weights[f"{prefix}.out_proj.bias"]
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.q_proj.weight.data = q_proj
 | 
			
		||||
    hf_attn_layer.q_proj.bias.data = q_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.k_proj.weight.data = k_proj
 | 
			
		||||
    hf_attn_layer.k_proj.bias.data = k_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.v_proj.weight.data = v_proj
 | 
			
		||||
    hf_attn_layer.v_proj.bias.data = v_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.out_proj.weight.data = out_proj_weights
 | 
			
		||||
    hf_attn_layer.out_proj.bias.data = out_proj_bias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_mlp(hf_mlp, pt_weights, prefix):
 | 
			
		||||
    copy_linear(hf_mlp.fc1, pt_weights, f"{prefix}.c_fc")
 | 
			
		||||
    copy_linear(hf_mlp.fc2, pt_weights, f"{prefix}.c_proj")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_linear(hf_linear, pt_weights, prefix):
 | 
			
		||||
    hf_linear.weight.data = pt_weights[f"{prefix}.weight"].data
 | 
			
		||||
    hf_linear.bias.data = pt_weights[f"{prefix}.bias"].data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layer(hf_layer, pt_weights, prefix):
 | 
			
		||||
    # copy layer norms
 | 
			
		||||
    copy_linear(hf_layer.layer_norm1, pt_weights, f"{prefix}.ln_1")
 | 
			
		||||
    copy_linear(hf_layer.layer_norm2, pt_weights, f"{prefix}.ln_2")
 | 
			
		||||
 | 
			
		||||
    # copy MLP
 | 
			
		||||
    copy_mlp(hf_layer.mlp, pt_weights, f"{prefix}.mlp")
 | 
			
		||||
 | 
			
		||||
    # copy attn
 | 
			
		||||
    copy_attn_layer(hf_layer.self_attn, pt_weights, f"{prefix}.attn")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layers(hf_layers, pt_weights, prefix):
 | 
			
		||||
    for layer_id, hf_layer in enumerate(hf_layers):
 | 
			
		||||
        copy_layer(hf_layer, pt_weights, f"{prefix}.{layer_id}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_text_model_and_projection(hf_model, pt_weights):
 | 
			
		||||
    # copy projection
 | 
			
		||||
    hf_model.text_projection.weight.data = pt_weights["text_projection"].data.T
 | 
			
		||||
 | 
			
		||||
    # copy text encoder
 | 
			
		||||
    for name, param in hf_model.text_model.named_parameters():
 | 
			
		||||
        param.data = pt_weights[f"bert.{name}"].data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_vision_model_and_projection(hf_model, pt_weights):
 | 
			
		||||
    # copy projection
 | 
			
		||||
    hf_model.visual_projection.weight.data = pt_weights["visual.proj"].data.T
 | 
			
		||||
 | 
			
		||||
    # copy layer norms
 | 
			
		||||
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_weights, "visual.ln_pre")
 | 
			
		||||
    copy_linear(hf_model.vision_model.post_layernorm, pt_weights, "visual.ln_post")
 | 
			
		||||
 | 
			
		||||
    # copy embeddings
 | 
			
		||||
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_weights["visual.conv1.weight"].data
 | 
			
		||||
    hf_model.vision_model.embeddings.class_embedding.data = pt_weights["visual.class_embedding"].data
 | 
			
		||||
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_weights["visual.positional_embedding"].data
 | 
			
		||||
 | 
			
		||||
    # copy encoder
 | 
			
		||||
    copy_layers(hf_model.vision_model.encoder.layers, pt_weights, "visual.transformer.resblocks")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_chinese_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to transformers design.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    assert config_path is not None, "Please specify the ChineseCLIP model config of the corresponding model size."
 | 
			
		||||
    config = ChineseCLIPConfig.from_pretrained(config_path)
 | 
			
		||||
 | 
			
		||||
    hf_model = ChineseCLIPModel(config).eval()
 | 
			
		||||
 | 
			
		||||
    pt_weights = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
 | 
			
		||||
    pt_weights = {(name[7:] if name.startswith("module.") else name): value for name, value in pt_weights.items()}
 | 
			
		||||
 | 
			
		||||
    copy_text_model_and_projection(hf_model, pt_weights)
 | 
			
		||||
    copy_vision_model_and_projection(hf_model, pt_weights)
 | 
			
		||||
    hf_model.logit_scale.data = pt_weights["logit_scale"].data
 | 
			
		||||
 | 
			
		||||
    hf_model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the output folder storing converted hf PyTorch model.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_path", default=None, type=str, help="Path to original github format ChineseCLIP checkpoint."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--config_path", default=None, required=True, type=str, help="Path to hf config.json of model to convert."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    convert_chinese_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
 | 
			
		||||
    print("The conversion is finished!")
 | 
			
		||||
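# Illustrative usage (added; not part of the original script). The filename and paths are
# placeholders; --config_path must point to a ChineseCLIP config.json matching the checkpoint size:
#   python convert_chinese_clip_original_pytorch_to_hf.py \
#       --checkpoint_path /path/to/chinese_clip.pt \
#       --config_path /path/to/config.json \
#       --pytorch_dump_folder_path /path/to/output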
@ -1,133 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
from laion_clap import CLAP_Module
 | 
			
		||||
 | 
			
		||||
from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KEYS_TO_MODIFY_MAPPING = {
 | 
			
		||||
    "text_branch": "text_model",
 | 
			
		||||
    "audio_branch": "audio_model.audio_encoder",
 | 
			
		||||
    "attn": "attention.self",
 | 
			
		||||
    "self.proj": "output.dense",
 | 
			
		||||
    "attention.self_mask": "attn_mask",
 | 
			
		||||
    "mlp.fc1": "intermediate.dense",
 | 
			
		||||
    "mlp.fc2": "output.dense",
 | 
			
		||||
    "norm1": "layernorm_before",
 | 
			
		||||
    "norm2": "layernorm_after",
 | 
			
		||||
    "bn0": "batch_norm",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_clap(checkpoint_path, model_type, enable_fusion=False):
 | 
			
		||||
    model = CLAP_Module(
 | 
			
		||||
        amodel=model_type,
 | 
			
		||||
        enable_fusion=enable_fusion,
 | 
			
		||||
    )
 | 
			
		||||
    model.load_ckpt(checkpoint_path)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_config_from_original(clap_model):
 | 
			
		||||
    audio_config = {
 | 
			
		||||
        "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
 | 
			
		||||
        "depths": clap_model.model.audio_branch.depths,
 | 
			
		||||
        "hidden_size": clap_model.model.audio_projection[0].in_features,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
 | 
			
		||||
 | 
			
		||||
    return ClapConfig(audio_config=audio_config, text_config=text_config)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_state_dict(state_dict):
 | 
			
		||||
    model_state_dict = {}
 | 
			
		||||
 | 
			
		||||
    sequential_layers_pattern = r".*sequential.(\d+).*"
 | 
			
		||||
    text_projection_pattern = r".*_projection.(\d+).*"
 | 
			
		||||
 | 
			
		||||
    for key, value in state_dict.items():
 | 
			
		||||
        # check if any key needs to be modified
 | 
			
		||||
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
 | 
			
		||||
            if key_to_modify in key:
 | 
			
		||||
                key = key.replace(key_to_modify, new_key)
 | 
			
		||||
 | 
			
		||||
        if re.match(sequential_layers_pattern, key):
 | 
			
		||||
            # replace sequential layers with list
 | 
			
		||||
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)
 | 
			
		||||
 | 
			
		||||
            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
 | 
			
		||||
        elif re.match(text_projection_pattern, key):
 | 
			
		||||
            projection_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projection_layer == 0 else 2

            key = key.replace(f"_projection.{projection_layer}.", f"_projection.linear{transformers_projection_layer}.")
 | 
			
		||||
 | 
			
		||||
        if "audio" and "qkv" in key:
 | 
			
		||||
            # split qkv into query key and value
 | 
			
		||||
            mixed_qkv = value
 | 
			
		||||
            qkv_dim = mixed_qkv.size(0) // 3
 | 
			
		||||
 | 
			
		||||
            query_layer = mixed_qkv[:qkv_dim]
 | 
			
		||||
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
 | 
			
		||||
            value_layer = mixed_qkv[qkv_dim * 2 :]
 | 
			
		||||
 | 
			
		||||
            model_state_dict[key.replace("qkv", "query")] = query_layer
 | 
			
		||||
            model_state_dict[key.replace("qkv", "key")] = key_layer
 | 
			
		||||
            model_state_dict[key.replace("qkv", "value")] = value_layer
 | 
			
		||||
        else:
 | 
			
		||||
            model_state_dict[key] = value
 | 
			
		||||
 | 
			
		||||
    return model_state_dict
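# Added illustration (not part of the original script): the slicing above splits a fused
# qkv tensor into equal query/key/value thirds along dim 0. With a dummy weight:
#   mixed_qkv = torch.randn(3 * 768, 768)
#   q, k, v = mixed_qkv.chunk(3, dim=0)  # same as mixed_qkv[:768], mixed_qkv[768:1536], mixed_qkv[1536:]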
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
 | 
			
		||||
    clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
 | 
			
		||||
 | 
			
		||||
    clap_model.eval()
 | 
			
		||||
    state_dict = clap_model.model.state_dict()
 | 
			
		||||
    state_dict = rename_state_dict(state_dict)
 | 
			
		||||
 | 
			
		||||
    transformers_config = get_config_from_original(clap_model)
 | 
			
		||||
    transformers_config.audio_config.enable_fusion = enable_fusion
 | 
			
		||||
    model = ClapModel(transformers_config)
 | 
			
		||||
 | 
			
		||||
    # ignore the spectrogram embedding layer
 | 
			
		||||
    model.load_state_dict(state_dict, strict=False)
 | 
			
		||||
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    transformers_config.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
 | 
			
		||||
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
 | 
			
		||||
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
 | 
			
		||||
    parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
 | 
			
		||||
    parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Whether to enable fusion or not")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    convert_clap_checkpoint(
 | 
			
		||||
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
 | 
			
		||||
    )
 | 
			
		||||
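# Illustrative usage (added; not part of the original script). The filename and paths are placeholders:
#   python convert_clap_original_pytorch_to_hf.py \
#       --checkpoint_path /path/to/laion_clap.pt \
#       --pytorch_dump_folder_path /path/to/output \
#       --model_type HTSAT-tiny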
@ -1,156 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from clip import load
 | 
			
		||||
 | 
			
		||||
from transformers import CLIPConfig, CLIPModel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
 | 
			
		||||
    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
 | 
			
		||||
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
 | 
			
		||||
 | 
			
		||||
    out_proj_weights = pt_attn_layer.out_proj.weight
 | 
			
		||||
    out_proj_bias = pt_attn_layer.out_proj.bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.q_proj.weight.data = q_proj
 | 
			
		||||
    hf_attn_layer.q_proj.bias.data = q_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.k_proj.weight.data = k_proj
 | 
			
		||||
    hf_attn_layer.k_proj.bias.data = k_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.v_proj.weight.data = v_proj
 | 
			
		||||
    hf_attn_layer.v_proj.bias.data = v_proj_bias
 | 
			
		||||
 | 
			
		||||
    hf_attn_layer.out_proj.weight = out_proj_weights
 | 
			
		||||
    hf_attn_layer.out_proj.bias = out_proj_bias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_mlp(hf_mlp, pt_mlp):
 | 
			
		||||
    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
 | 
			
		||||
    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_linear(hf_linear, pt_linear):
 | 
			
		||||
    hf_linear.weight = pt_linear.weight
 | 
			
		||||
    hf_linear.bias = pt_linear.bias
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layer(hf_layer, pt_layer):
 | 
			
		||||
    # copy layer norms
 | 
			
		||||
    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
 | 
			
		||||
    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)
 | 
			
		||||
 | 
			
		||||
    # copy MLP
 | 
			
		||||
    copy_mlp(hf_layer.mlp, pt_layer.mlp)
 | 
			
		||||
 | 
			
		||||
    # copy attn
 | 
			
		||||
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_layers(hf_layers, pt_layers):
 | 
			
		||||
    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
 | 
			
		||||
        copy_layer(hf_layer, pt_layer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_encoder(hf_encoder, pt_model):
 | 
			
		||||
    # copy  embeds
 | 
			
		||||
    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
 | 
			
		||||
    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding
 | 
			
		||||
 | 
			
		||||
    # copy layer norm
 | 
			
		||||
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
 | 
			
		||||
 | 
			
		||||
    # copy hidden layers
 | 
			
		||||
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_text_model_and_projection(hf_model, pt_model):
 | 
			
		||||
    # copy projection
 | 
			
		||||
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()
 | 
			
		||||
 | 
			
		||||
    # copy text encoder
 | 
			
		||||
    copy_encoder(hf_model.text_model, pt_model)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_vision_model_and_projection(hf_model, pt_model):
 | 
			
		||||
    # copy projection
 | 
			
		||||
    hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()
 | 
			
		||||
 | 
			
		||||
    # copy layer norms
 | 
			
		||||
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
 | 
			
		||||
    copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)
 | 
			
		||||
 | 
			
		||||
    # copy embeds
 | 
			
		||||
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
 | 
			
		||||
    hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
 | 
			
		||||
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data
 | 
			
		||||
 | 
			
		||||
    # copy encoder
 | 
			
		||||
    copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to transformers design.
 | 
			
		||||
    """
 | 
			
		||||
    if config_path is not None:
 | 
			
		||||
        config = CLIPConfig.from_pretrained(config_path)
 | 
			
		||||
    else:
 | 
			
		||||
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
 | 
			
		||||
 | 
			
		||||
    hf_model = CLIPModel(config).eval()
 | 
			
		||||
 | 
			
		||||
    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
 | 
			
		||||
    pt_model = pt_model.eval()
 | 
			
		||||
 | 
			
		||||
    copy_text_model_and_projection(hf_model, pt_model)
 | 
			
		||||
    copy_vision_model_and_projection(hf_model, pt_model)
 | 
			
		||||
    hf_model.logit_scale = pt_model.logit_scale
 | 
			
		||||
 | 
			
		||||
    # Use `eos_token` so the example is more meaningful
 | 
			
		||||
    input_ids = torch.tensor(
 | 
			
		||||
        [
 | 
			
		||||
            [config.text_config.bos_token_id]
 | 
			
		||||
            + list(range(3, 77))
 | 
			
		||||
            + [config.text_config.eos_token_id]
 | 
			
		||||
            + [config.text_config.pad_token_id]
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
    pixel_values = torch.randn(1, 3, 224, 224)
 | 
			
		||||
 | 
			
		||||
    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
 | 
			
		||||
    hf_logits_per_image = hf_outputs.logits_per_image
 | 
			
		||||
    hf_logits_per_text = hf_outputs.logits_per_text
 | 
			
		||||
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
 | 
			
		||||
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)
 | 
			
		||||
 | 
			
		||||
    hf_model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
 | 
			
		||||
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to OpenAI checkpoint")
 | 
			
		||||
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
 | 
			
		||||
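# Illustrative usage (added; not part of the original script). The filename and paths are
# placeholders; --checkpoint_path should be an OpenAI CLIP checkpoint loadable with clip.load:
#   python convert_clip_original_pytorch_to_hf.py \
#       --checkpoint_path /path/to/ViT-B-32.pt \
#       --pytorch_dump_folder_path /path/to/output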
@ -1,264 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""Convert CLIPSeg checkpoints from the original repository. URL: https://github.com/timojl/clipseg."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    CLIPSegConfig,
 | 
			
		||||
    CLIPSegForImageSegmentation,
 | 
			
		||||
    CLIPSegProcessor,
 | 
			
		||||
    CLIPSegTextConfig,
 | 
			
		||||
    CLIPSegVisionConfig,
 | 
			
		||||
    CLIPTokenizer,
 | 
			
		||||
    ViTImageProcessor,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_clipseg_config(model_name):
 | 
			
		||||
    text_config = CLIPSegTextConfig()
 | 
			
		||||
    vision_config = CLIPSegVisionConfig(patch_size=16)
 | 
			
		||||
 | 
			
		||||
    use_complex_transposed_convolution = "refined" in model_name
 | 
			
		||||
    reduce_dim = 16 if "rd16" in model_name else 64
 | 
			
		||||
 | 
			
		||||
    config = CLIPSegConfig.from_text_vision_configs(
 | 
			
		||||
        text_config,
 | 
			
		||||
        vision_config,
 | 
			
		||||
        use_complex_transposed_convolution=use_complex_transposed_convolution,
 | 
			
		||||
        reduce_dim=reduce_dim,
 | 
			
		||||
    )
 | 
			
		||||
    return config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    # update prefixes
 | 
			
		||||
    if "clip_model" in name:
 | 
			
		||||
        name = name.replace("clip_model", "clip")
 | 
			
		||||
    if "transformer" in name:
 | 
			
		||||
        if "visual" in name:
 | 
			
		||||
            name = name.replace("visual.transformer", "vision_model")
 | 
			
		||||
        else:
 | 
			
		||||
            name = name.replace("transformer", "text_model")
 | 
			
		||||
    if "resblocks" in name:
 | 
			
		||||
        name = name.replace("resblocks", "encoder.layers")
 | 
			
		||||
    if "ln_1" in name:
 | 
			
		||||
        name = name.replace("ln_1", "layer_norm1")
 | 
			
		||||
    if "ln_2" in name:
 | 
			
		||||
        name = name.replace("ln_2", "layer_norm2")
 | 
			
		||||
    if "c_fc" in name:
 | 
			
		||||
        name = name.replace("c_fc", "fc1")
 | 
			
		||||
    if "c_proj" in name:
 | 
			
		||||
        name = name.replace("c_proj", "fc2")
 | 
			
		||||
    if "attn" in name and "self" not in name:
 | 
			
		||||
        name = name.replace("attn", "self_attn")
 | 
			
		||||
    # text encoder
 | 
			
		||||
    if "token_embedding" in name:
 | 
			
		||||
        name = name.replace("token_embedding", "text_model.embeddings.token_embedding")
 | 
			
		||||
    if "positional_embedding" in name and "visual" not in name:
 | 
			
		||||
        name = name.replace("positional_embedding", "text_model.embeddings.position_embedding.weight")
 | 
			
		||||
    if "ln_final" in name:
 | 
			
		||||
        name = name.replace("ln_final", "text_model.final_layer_norm")
 | 
			
		||||
    # vision encoder
 | 
			
		||||
    if "visual.class_embedding" in name:
 | 
			
		||||
        name = name.replace("visual.class_embedding", "vision_model.embeddings.class_embedding")
 | 
			
		||||
    if "visual.conv1" in name:
 | 
			
		||||
        name = name.replace("visual.conv1", "vision_model.embeddings.patch_embedding")
 | 
			
		||||
    if "visual.positional_embedding" in name:
 | 
			
		||||
        name = name.replace("visual.positional_embedding", "vision_model.embeddings.position_embedding.weight")
 | 
			
		||||
    if "visual.ln_pre" in name:
 | 
			
		||||
        name = name.replace("visual.ln_pre", "vision_model.pre_layrnorm")
 | 
			
		||||
    if "visual.ln_post" in name:
 | 
			
		||||
        name = name.replace("visual.ln_post", "vision_model.post_layernorm")
 | 
			
		||||
    # projection layers
 | 
			
		||||
    if "visual.proj" in name:
 | 
			
		||||
        name = name.replace("visual.proj", "visual_projection.weight")
 | 
			
		||||
    if "text_projection" in name:
 | 
			
		||||
        name = name.replace("text_projection", "text_projection.weight")
 | 
			
		||||
    # decoder
 | 
			
		||||
    if "trans_conv" in name:
 | 
			
		||||
        name = name.replace("trans_conv", "transposed_convolution")
 | 
			
		||||
    if "film_mul" in name or "film_add" in name or "reduce" in name or "transposed_convolution" in name:
 | 
			
		||||
        name = "decoder." + name
 | 
			
		||||
    if "blocks" in name:
 | 
			
		||||
        name = name.replace("blocks", "decoder.layers")
 | 
			
		||||
    if "linear1" in name:
 | 
			
		||||
        name = name.replace("linear1", "mlp.fc1")
 | 
			
		||||
    if "linear2" in name:
 | 
			
		||||
        name = name.replace("linear2", "mlp.fc2")
 | 
			
		||||
    if "norm1" in name and "layer_" not in name:
 | 
			
		||||
        name = name.replace("norm1", "layer_norm1")
 | 
			
		||||
    if "norm2" in name and "layer_" not in name:
 | 
			
		||||
        name = name.replace("norm2", "layer_norm2")
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_state_dict(orig_state_dict, config):
 | 
			
		||||
    for key in orig_state_dict.copy().keys():
 | 
			
		||||
        val = orig_state_dict.pop(key)
 | 
			
		||||
 | 
			
		||||
        if key.startswith("clip_model") and "attn.in_proj" in key:
 | 
			
		||||
            key_split = key.split(".")
 | 
			
		||||
            if "visual" in key:
 | 
			
		||||
                layer_num = int(key_split[4])
 | 
			
		||||
                dim = config.vision_config.hidden_size
 | 
			
		||||
                prefix = "vision_model"
 | 
			
		||||
            else:
 | 
			
		||||
                layer_num = int(key_split[3])
 | 
			
		||||
                dim = config.text_config.hidden_size
 | 
			
		||||
                prefix = "text_model"
 | 
			
		||||
 | 
			
		||||
            if "weight" in key:
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[
 | 
			
		||||
                    dim : dim * 2, :
 | 
			
		||||
                ]
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
 | 
			
		||||
            else:
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
 | 
			
		||||
                orig_state_dict[f"clip.{prefix}.encoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
 | 
			
		||||
        elif "self_attn" in key and "out_proj" not in key:
 | 
			
		||||
            key_split = key.split(".")
 | 
			
		||||
            layer_num = int(key_split[1])
 | 
			
		||||
            dim = config.reduce_dim
 | 
			
		||||
            if "weight" in key:
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.weight"] = val[:dim, :]
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.weight"] = val[dim : dim * 2, :]
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.weight"] = val[-dim:, :]
 | 
			
		||||
            else:
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.q_proj.bias"] = val[:dim]
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.k_proj.bias"] = val[dim : dim * 2]
 | 
			
		||||
                orig_state_dict[f"decoder.layers.{layer_num}.self_attn.v_proj.bias"] = val[-dim:]
 | 
			
		||||
        else:
 | 
			
		||||
            new_name = rename_key(key)
 | 
			
		||||
            if "visual_projection" in new_name or "text_projection" in new_name:
 | 
			
		||||
                val = val.T
 | 
			
		||||
            orig_state_dict[new_name] = val
 | 
			
		||||
 | 
			
		||||
    return orig_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    image = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_clipseg_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
 | 
			
		||||
    config = get_clipseg_config(model_name)
 | 
			
		||||
    model = CLIPSegForImageSegmentation(config)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    state_dict = torch.load(checkpoint_path, map_location="cpu")
 | 
			
		||||
 | 
			
		||||
    # remove some keys
 | 
			
		||||
    for key in state_dict.copy().keys():
 | 
			
		||||
        if key.startswith("model"):
 | 
			
		||||
            state_dict.pop(key, None)
 | 
			
		||||
 | 
			
		||||
    # rename some keys
 | 
			
		||||
    state_dict = convert_state_dict(state_dict, config)
 | 
			
		||||
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
 | 
			
		||||
 | 
			
		||||
    if missing_keys != ["clip.text_model.embeddings.position_ids", "clip.vision_model.embeddings.position_ids"]:
 | 
			
		||||
        raise ValueError("Missing keys that are not expected: {}".format(missing_keys))
 | 
			
		||||
    if unexpected_keys != ["decoder.reduce.weight", "decoder.reduce.bias"]:
 | 
			
		||||
        raise ValueError(f"Unexpected keys: {unexpected_keys}")
 | 
			
		||||
 | 
			
		||||
    image_processor = ViTImageProcessor(size=352)
 | 
			
		||||
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
 | 
			
		||||
    processor = CLIPSegProcessor(image_processor=image_processor, tokenizer=tokenizer)
 | 
			
		||||
 | 
			
		||||
    image = prepare_img()
 | 
			
		||||
    text = ["a glass", "something to fill", "wood", "a jar"]
 | 
			
		||||
 | 
			
		||||
    inputs = processor(text=text, images=[image] * len(text), padding="max_length", return_tensors="pt")
 | 
			
		||||
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        outputs = model(**inputs)
 | 
			
		||||
 | 
			
		||||
    # verify values
 | 
			
		||||
    expected_conditional = torch.tensor([0.1110, -0.1882, 0.1645])
 | 
			
		||||
    expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328])
 | 
			
		||||
    if model_name == "clipseg-rd64-refined":
 | 
			
		||||
        expected_masks_slice = torch.tensor(
 | 
			
		||||
            [[-10.0407, -9.9431, -10.2646], [-9.9751, -9.7064, -9.9586], [-9.6891, -9.5645, -9.9618]]
 | 
			
		||||
        )
 | 
			
		||||
    elif model_name == "clipseg-rd64":
 | 
			
		||||
        expected_masks_slice = torch.tensor(
 | 
			
		||||
            [[-7.2877, -7.2711, -7.2463], [-7.2652, -7.2780, -7.2520], [-7.2239, -7.2204, -7.2001]]
 | 
			
		||||
        )
 | 
			
		||||
    elif model_name == "clipseg-rd16":
 | 
			
		||||
        expected_masks_slice = torch.tensor(
 | 
			
		||||
            [[-6.3955, -6.4055, -6.4151], [-6.3911, -6.4033, -6.4100], [-6.3474, -6.3702, -6.3762]]
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Model name {model_name} not supported.")
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)
 | 
			
		||||
    assert torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)
 | 
			
		||||
    assert torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)
 | 
			
		||||
    print("Looks ok!")
 | 
			
		||||
 | 
			
		||||
    if pytorch_dump_folder_path is not None:
 | 
			
		||||
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
 | 
			
		||||
        model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
        processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        print(f"Pushing model and processor for {model_name} to the hub")
 | 
			
		||||
        model.push_to_hub(f"CIDAS/{model_name}")
 | 
			
		||||
        processor.push_to_hub(f"CIDAS/{model_name}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_name",
 | 
			
		||||
        default="clipseg-rd64",
 | 
			
		||||
        type=str,
 | 
			
		||||
        choices=["clipseg-rd16", "clipseg-rd64", "clipseg-rd64-refined"],
 | 
			
		||||
        help=(
 | 
			
		||||
            "Name of the model. Supported models are: clipseg-rd64, clipseg-rd16 and clipseg-rd64-refined (rd meaning"
 | 
			
		||||
            " reduce dimension)"
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_path",
 | 
			
		||||
        default="/Users/nielsrogge/Documents/CLIPSeg/clip_plus_rd64-uni.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help=(
 | 
			
		||||
            "Path to the original checkpoint. Note that the script assumes that the checkpoint includes both CLIP and"
 | 
			
		||||
            " the decoder weights."
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_clipseg_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
 | 
			
		||||
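# Illustrative usage (added; not part of the original script). The filename is a placeholder;
# as the --checkpoint_path help notes, the checkpoint must include both CLIP and decoder weights:
#   python convert_clipseg_original_pytorch_to_hf.py \
#       --model_name clipseg-rd64 \
#       --checkpoint_path /path/to/clip_plus_rd64-uni.pth \
#       --pytorch_dump_folder_path /path/to/output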
@ -1,234 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
Weights conversion script for CLVP
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
 | 
			
		||||
from transformers import ClvpConfig, ClvpModelForConditionalGeneration
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_MODELS = {
 | 
			
		||||
    "clvp": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/clvp2.pth",
 | 
			
		||||
    "decoder": "https://huggingface.co/jbetker/tortoise-tts-v2/blob/main/.models/autoregressive.pth",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
dim = 1024
 | 
			
		||||
sub_dim = dim // 16
 | 
			
		||||
 | 
			
		||||
CLVP_ENCODERS_MAPPING = {
 | 
			
		||||
    "text_transformer.transformer.attn_layers": "text_encoder_model",
 | 
			
		||||
    "speech_transformer.transformer.attn_layers": "speech_encoder_model",
 | 
			
		||||
    "text_transformer.transformer.norm": "text_encoder_model.final_layer_norm",
 | 
			
		||||
    "speech_transformer.transformer.norm": "speech_encoder_model.final_layer_norm",
 | 
			
		||||
    "to_text_latent": "text_encoder_model.projection",
 | 
			
		||||
    "to_speech_latent": "speech_encoder_model.projection",
 | 
			
		||||
    "text_emb": "text_encoder_model.token_embedding",
 | 
			
		||||
    "speech_emb": "speech_encoder_model.token_embedding",
 | 
			
		||||
    "1.wrap.net.0": "mlp.fc1",
 | 
			
		||||
    "1.wrap.net.3": "mlp.fc2",
 | 
			
		||||
    "1.wrap": "self_attn",
 | 
			
		||||
    "to_out": "out_proj",
 | 
			
		||||
    "to_q": "q_proj",
 | 
			
		||||
    "to_k": "k_proj",
 | 
			
		||||
    "to_v": "v_proj",
 | 
			
		||||
    "temperature": "logit_scale",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
CLVP_DECODER_MAPPING = {
 | 
			
		||||
    "conditioning_encoder.init": "conditioning_encoder.mel_conv",
 | 
			
		||||
    "conditioning_encoder.attn": "conditioning_encoder.mel_attn_blocks",
 | 
			
		||||
    "mel_attn_blocks": "group_norms",
 | 
			
		||||
    ".norm.weight": ".weight",
 | 
			
		||||
    ".norm.bias": ".bias",
 | 
			
		||||
    "text_embedding": "conditioning_encoder.text_token_embedding",
 | 
			
		||||
    "text_pos_embedding.emb": "conditioning_encoder.text_position_embedding",
 | 
			
		||||
    "final_norm": "speech_decoder_model.final_norm",
 | 
			
		||||
    "mel_head": "speech_decoder_model.lm_head",
 | 
			
		||||
    "gpt.ln_f": "speech_decoder_model.model.decoder.layer_norm",
 | 
			
		||||
    "mel_embedding": "speech_decoder_model.model.decoder.input_embeds_layer",
 | 
			
		||||
    "mel_pos_embedding.emb": "speech_decoder_model.model.decoder.position_embeds_layer",
 | 
			
		||||
    "gpt.h": "speech_decoder_model.model.decoder.layers",
 | 
			
		||||
    "ln_1": "input_layernorm",
 | 
			
		||||
    "ln_2": "post_attention_layernorm",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_index(present_index):
 | 
			
		||||
    if present_index % 2 == 0:
 | 
			
		||||
        return int(present_index / 2)
 | 
			
		||||
    else:
 | 
			
		||||
        return int((present_index - 1) / 2)
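# Added note (not in the original script): update_index collapses consecutive index pairs,
# mapping 0, 1 -> 0, then 2, 3 -> 1, then 4, 5 -> 2, and so on, presumably because the original
# attn_layers stack interleaves two sub-layers per converted transformer layer.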
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_encoder_weights(original_weights):
 | 
			
		||||
    converted_weights = {}
 | 
			
		||||
    original_weights_keys = sorted(original_weights.keys())
 | 
			
		||||
    for original_key in original_weights_keys:
 | 
			
		||||
        updated_key = original_key
 | 
			
		||||
        # for input_rmsnorm.weight and post_attention_rmsnorm.weight
 | 
			
		||||
        if "0.0.g" in updated_key:
 | 
			
		||||
            present_index = updated_key.split(".")[4]
 | 
			
		||||
            if int(present_index) % 2 == 0:
 | 
			
		||||
                updated_key = updated_key.replace("0.0.g", "input_rmsnorm.weight")
 | 
			
		||||
            else:
 | 
			
		||||
                updated_key = updated_key.replace("0.0.g", "post_attention_rmsnorm.weight")
 | 
			
		||||
 | 
			
		||||
        if "transformer.attn_layers.layers" in updated_key:
 | 
			
		||||
            present_index = updated_key.split(".")[4]
 | 
			
		||||
            updated_index = update_index(int(present_index))
 | 
			
		||||
            updated_key = updated_key.replace(
 | 
			
		||||
                f"transformer.attn_layers.layers.{present_index}", f"transformer.attn_layers.layers.{updated_index}"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for k, v in CLVP_ENCODERS_MAPPING.items():
 | 
			
		||||
            if k in updated_key:
 | 
			
		||||
                updated_key = updated_key.replace(k, v)
 | 
			
		||||
 | 
			
		||||
        converted_weights[updated_key] = original_weights.pop(original_key)
 | 
			
		||||
 | 
			
		||||
    return converted_weights
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_decoder_weights(original_weights):
 | 
			
		||||
    converted_weights = {}
 | 
			
		||||
    original_weights_keys = sorted(original_weights.keys())

    for original_key in original_weights_keys:
        updated_key = original_key

        if len(updated_key.split(".")) > 3:
            index, attr = updated_key.split(".")[2], updated_key.split(".")[-1]

        # for decoder attention
        if "attn.c_attn" in updated_key:
            if attr == "weight":
                slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).T.split(split_size=dim, dim=0)
            else:
                slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.q_proj.{attr}"] = slice1
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.k_proj.{attr}"] = slice2
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.v_proj.{attr}"] = slice3
            continue

        if "attn.c_proj" in updated_key:
            converted_weights[f"speech_decoder_model.model.decoder.layers.{index}.attn.out_proj.{attr}"] = (
                original_weights[updated_key].squeeze(-1).T
            )
            continue

        if "attn.bias" in updated_key or "attn.masked_bias" in updated_key or "text_head" in updated_key:
            original_weights.pop(updated_key)
            continue

        # conditional encoder attention
        if "qkv" in updated_key:
            if attr == "weight":
                slice1, slice2, slice3 = original_weights[updated_key].squeeze(-1).split(split_size=dim, dim=0)
            else:
                slice1, slice2, slice3 = original_weights[updated_key].split(split_size=dim, dim=0)

            indices = torch.arange(dim)
            index1, index2, index3 = (
                indices.unfold(0, sub_dim, sub_dim * 3).flatten(),
                indices[sub_dim:].unfold(0, sub_dim, sub_dim * 3).flatten(),
                indices[2 * sub_dim :].unfold(0, sub_dim, sub_dim * 3).flatten(),
            )

            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.q_proj.{attr}"] = torch.concatenate(
                [slice1[index1], slice2[index3], slice3[index2]],
                axis=0,
            )
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.k_proj.{attr}"] = torch.concatenate(
                [slice1[index2], slice2[index1], slice3[index3]],
                axis=0,
            )
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.v_proj.{attr}"] = torch.concatenate(
                [slice1[index3], slice2[index2], slice3[index1]],
                axis=0,
            )
            continue

        if "proj_out" in updated_key:
            converted_weights[f"conditioning_encoder.mel_attn_blocks.{index}.out_proj.{attr}"] = original_weights[
                updated_key
            ].squeeze(-1)
            continue

        for k, v in CLVP_DECODER_MAPPING.items():
            if k in updated_key:
                updated_key = updated_key.replace(k, v)

        converted_weights[updated_key] = original_weights.pop(original_key)

    return converted_weights


def _download(url: str, root: str):
    repo_id = f"{url.split('/')[3]}/{url.split('/')[4]}"
    filename = f"{url.split('/')[-2]}/{url.split('/')[-1]}"
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_filename=root,
        local_dir_use_symlinks=False,
    )


def convert_clvp_weights(checkpoint_path, pytorch_dump_folder_path):
    converted_checkpoint = {}

    for each_model_name, each_model_url in _MODELS.items():
        each_model_path = os.path.join(checkpoint_path, each_model_url.split("/")[-1])
        if not os.path.exists(each_model_path):
            print(f"\n{each_model_name} was not found! Downloading it to {each_model_path}")
            _download(url=each_model_url, root=each_model_path)

        if each_model_name == "clvp":
            clvp_checkpoint = torch.load(each_model_path, map_location="cpu")
        else:
            decoder_checkpoint = torch.load(each_model_path, map_location="cpu")

    # Converting the weights
    converted_checkpoint.update(**convert_encoder_weights(clvp_checkpoint))
    converted_checkpoint.update(**convert_decoder_weights(decoder_checkpoint))

    config = ClvpConfig.from_pretrained("susnato/clvp_dev")
    model = ClvpModelForConditionalGeneration(config)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Model saved at {pytorch_dump_folder_path}!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", type=str, help="Path to the folder of downloaded checkpoints. (Please enter full path)"
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model. (Please enter full path)",
    )
    args = parser.parse_args()

    convert_clvp_weights(args.checkpoint_path, args.pytorch_dump_folder_path)
@@ -1,214 +0,0 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColPali weights from the original repository to the HF model format.

Original repository: https://github.com/illuin-tech/colpali.

NOTE: This script was originally run using `torch==2.5.1` and with:

```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
    --model_id vidore/colpali-v1.2-merged \
    --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
    --original_vlm_name_or_path google/paligemma-3b-mix-448 \
    --output_dir vidore/colpali-v1.2-hf-internal \
    --push_to_hub

python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
    --model_id vidore/colpali-v1.3-merged \
    --revision 5b955e3415a7c5468ab33119d98d6d45c3a5b2c3 \
    --original_vlm_name_or_path google/paligemma-3b-mix-448 \
    --output_dir vidore/colpali-v1.3-hf \
    --push_to_hub
```
"""

import argparse
import glob
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from transformers import AutoConfig
from transformers.models.colpali import ColPaliForRetrieval
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


ORIGINAL_DTYPE = torch.bfloat16


def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        if key.startswith("custom_text_proj"):
            new_key = key.replace("custom_text_proj", "embedding_proj_layer")
        if key.startswith("model."):
            new_key = key.replace("model.", "vlm.", 1)
        new_state_dict[new_key] = value
    return new_state_dict


def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
    directory_path = snapshot_download(
        repo_id=model_id,
        revision=revision,
        allow_patterns=["*.safetensors"],
    )

    original_state_dict = {}
    for path in glob.glob(f"{directory_path}/*"):
        if path.endswith(".safetensors"):
            with safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    original_state_dict[key] = f.get_tensor(key)

    # Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict.
    if "lm_head.weight" not in original_state_dict:
        original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
            "model.language_model.model.embed_tokens.weight"
        ].clone()

    return original_state_dict


@torch.no_grad()
def convert_colpali_weights_to_hf(
    model_id: str,
    output_dir: str,
    push_to_hub: bool,
    revision: Optional[str] = None,
    original_vlm_name_or_path: Optional[str] = None,
):
    # Load the original model data
    original_config = AutoConfig.from_pretrained(
        model_id,
        revision=revision,
    )
    if original_vlm_name_or_path is not None:
        original_config._name_or_path = original_vlm_name_or_path
    if hasattr(original_config, "architectures"):
        delattr(original_config, "architectures")

    original_state_dict = load_original_state_dict(model_id, revision=revision)

    # Format the state_dict keys
    original_state_dict = rename_state_dict_keys(original_state_dict)

    # Create the new config
    config = ColPaliConfig(
        vlm_config=original_config,
        embedding_dim=128,  # hardcoded in the original model
    )
    config.model_type = "colpali"
    config.is_composition = False

    # Load the untrained model
    model = ColPaliForRetrieval(config=config).to("cpu").eval()
    print("Created model with new config and randomly initialized weights")

    # NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
    # There are two ways to set the model's dtype:
    # - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
    # - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
    # The following snippet allows a fine-grained control over the model's dtype, making sure that all
    # the new weights' dtypes match the original model.
    for param in model.parameters():
        param.data = param.data.to(ORIGINAL_DTYPE)
    print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")

    # Load the original weights
    model.load_state_dict(original_state_dict)
    print("Loaded original model weights")

    # Tie the weights (following ColPali's `__init__` step)
    if model.vlm.language_model._tied_weights_keys is not None:
        model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]

    # Sanity check: ensure all keys are the same
    state_dict_keys_old = set(original_state_dict.keys())
    state_dict_keys_new = set(model.state_dict().keys())
    disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
    if disjoint_keys:
        raise ValueError(f"Incompatible keys: {disjoint_keys}")

    # Save the model
    if push_to_hub:
        model.push_to_hub(output_dir, private=True)
        print(f"Model pushed to the hub at `{output_dir}`")
    else:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
        model.save_pretrained(output_dir)
        print(f"Model saved to `{output_dir}`")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        This script converts the original ColPali model to the HF model format.

        Example usage:
        ```bash
        python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
            --model_id vidore/colpali-v1.2-merged \
            --revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
            --original_vlm_name_or_path google/paligemma-3b-mix-448 \
            --output_dir vidore/colpali-v1.2-hf \
            --push_to_hub
        ```
        """
    )
    parser.add_argument(
        "--model_id",
        help="Model ID of the original model to convert",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--revision",
        help="Revision of the model to download",
        default=None,
    )
    parser.add_argument(
        "--original_vlm_name_or_path",
        help="Name or path of the original VLM backbone model",
        default=None,
    )
    args = parser.parse_args()

    convert_colpali_weights_to_hf(
        model_id=args.model_id,
        output_dir=args.output_dir,
        push_to_hub=args.push_to_hub,
        revision=args.revision,
        original_vlm_name_or_path=args.original_vlm_name_or_path,
    )
@@ -1,324 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Conditional DETR checkpoints."""

import argparse
import json
from collections import OrderedDict
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import (
    ConditionalDetrConfig,
    ConditionalDetrForObjectDetection,
    ConditionalDetrForSegmentation,
    ConditionalDetrImageProcessor,
)
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
for i in range(6):
    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.weight", f"encoder.layers.{i}.self_attn.out_proj.weight")
    )
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.self_attn.out_proj.bias", f"encoder.layers.{i}.self_attn.out_proj.bias")
    )
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"encoder.layers.{i}.fc1.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"encoder.layers.{i}.fc1.bias"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"encoder.layers.{i}.fc2.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"encoder.layers.{i}.fc2.bias"))
    rename_keys.append(
        (f"transformer.encoder.layers.{i}.norm1.weight", f"encoder.layers.{i}.self_attn_layer_norm.weight")
    )
    rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"encoder.layers.{i}.self_attn_layer_norm.bias"))
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"encoder.layers.{i}.final_layer_norm.weight"))
    rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"encoder.layers.{i}.final_layer_norm.bias"))
    # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"decoder.layers.{i}.self_attn.out_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"decoder.layers.{i}.self_attn.out_proj.bias")
    )
    rename_keys.append(
        (
            f"transformer.decoder.layers.{i}.cross_attn.out_proj.weight",
            f"decoder.layers.{i}.encoder_attn.out_proj.weight",
        )
    )
    rename_keys.append(
        (
            f"transformer.decoder.layers.{i}.cross_attn.out_proj.bias",
            f"decoder.layers.{i}.encoder_attn.out_proj.bias",
        )
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"decoder.layers.{i}.fc1.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"decoder.layers.{i}.fc1.bias"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"decoder.layers.{i}.fc2.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"decoder.layers.{i}.fc2.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm1.weight", f"decoder.layers.{i}.self_attn_layer_norm.weight")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"decoder.layers.{i}.self_attn_layer_norm.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm2.weight", f"decoder.layers.{i}.encoder_attn_layer_norm.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.norm2.bias", f"decoder.layers.{i}.encoder_attn_layer_norm.bias")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"decoder.layers.{i}.final_layer_norm.weight"))
    rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"decoder.layers.{i}.final_layer_norm.bias"))

    # q, k, v projections in self/cross-attention in decoder for conditional DETR
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_qcontent_proj.weight", f"decoder.layers.{i}.sa_qcontent_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_kcontent_proj.weight", f"decoder.layers.{i}.sa_kcontent_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_qpos_proj.weight", f"decoder.layers.{i}.sa_qpos_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_kpos_proj.weight", f"decoder.layers.{i}.sa_kpos_proj.weight")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.weight", f"decoder.layers.{i}.sa_v_proj.weight"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_qcontent_proj.weight", f"decoder.layers.{i}.ca_qcontent_proj.weight")
    )
    # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.weight", f"decoder.layers.{i}.ca_qpos_proj.weight"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_kcontent_proj.weight", f"decoder.layers.{i}.ca_kcontent_proj.weight")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_kpos_proj.weight", f"decoder.layers.{i}.ca_kpos_proj.weight")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.weight", f"decoder.layers.{i}.ca_v_proj.weight"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.weight", f"decoder.layers.{i}.ca_qpos_sine_proj.weight")
    )

    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_qcontent_proj.bias", f"decoder.layers.{i}.sa_qcontent_proj.bias")
    )
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.sa_kcontent_proj.bias", f"decoder.layers.{i}.sa_kcontent_proj.bias")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_qpos_proj.bias", f"decoder.layers.{i}.sa_qpos_proj.bias"))
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_kpos_proj.bias", f"decoder.layers.{i}.sa_kpos_proj.bias"))
    rename_keys.append((f"transformer.decoder.layers.{i}.sa_v_proj.bias", f"decoder.layers.{i}.sa_v_proj.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_qcontent_proj.bias", f"decoder.layers.{i}.ca_qcontent_proj.bias")
    )
    # rename_keys.append((f"transformer.decoder.layers.{i}.ca_qpos_proj.bias", f"decoder.layers.{i}.ca_qpos_proj.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_kcontent_proj.bias", f"decoder.layers.{i}.ca_kcontent_proj.bias")
    )
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_kpos_proj.bias", f"decoder.layers.{i}.ca_kpos_proj.bias"))
    rename_keys.append((f"transformer.decoder.layers.{i}.ca_v_proj.bias", f"decoder.layers.{i}.ca_v_proj.bias"))
    rename_keys.append(
        (f"transformer.decoder.layers.{i}.ca_qpos_sine_proj.bias", f"decoder.layers.{i}.ca_qpos_sine_proj.bias")
    )

# convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
# for conditional DETR, also convert reference point head and query scale MLP
rename_keys.extend(
    [
        ("input_proj.weight", "input_projection.weight"),
        ("input_proj.bias", "input_projection.bias"),
        ("query_embed.weight", "query_position_embeddings.weight"),
        ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
        ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
        ("class_embed.weight", "class_labels_classifier.weight"),
        ("class_embed.bias", "class_labels_classifier.bias"),
        ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
        ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
        ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
        ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
        ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
        ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
        ("transformer.decoder.ref_point_head.layers.0.weight", "decoder.ref_point_head.layers.0.weight"),
        ("transformer.decoder.ref_point_head.layers.0.bias", "decoder.ref_point_head.layers.0.bias"),
        ("transformer.decoder.ref_point_head.layers.1.weight", "decoder.ref_point_head.layers.1.weight"),
        ("transformer.decoder.ref_point_head.layers.1.bias", "decoder.ref_point_head.layers.1.bias"),
        ("transformer.decoder.query_scale.layers.0.weight", "decoder.query_scale.layers.0.weight"),
        ("transformer.decoder.query_scale.layers.0.bias", "decoder.query_scale.layers.0.bias"),
        ("transformer.decoder.query_scale.layers.1.weight", "decoder.query_scale.layers.1.weight"),
        ("transformer.decoder.query_scale.layers.1.bias", "decoder.query_scale.layers.1.bias"),
        ("transformer.decoder.layers.0.ca_qpos_proj.weight", "decoder.layers.0.ca_qpos_proj.weight"),
        ("transformer.decoder.layers.0.ca_qpos_proj.bias", "decoder.layers.0.ca_qpos_proj.bias"),
    ]
)


def rename_key(state_dict, old, new):
    val = state_dict.pop(old)
    state_dict[new] = val


def rename_backbone_keys(state_dict):
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if "backbone.0.body" in key:
            new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model")
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value

    return new_state_dict


def read_in_q_k_v(state_dict, is_panoptic=False):
    prefix = ""
    if is_panoptic:
        prefix = "conditional_detr."

    # first: transformer encoder
    for i in range(6):
        # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"{prefix}transformer.encoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"encoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"encoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"encoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im


@torch.no_grad()
def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our CONDITIONAL_DETR structure.
    """

    # load default config
    config = ConditionalDetrConfig()
    # set backbone and dilation attributes
    if "resnet101" in model_name:
        config.backbone = "resnet101"
    if "dc5" in model_name:
        config.dilation = True
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        repo_id = "huggingface/label-files"
        filename = "coco-detection-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    # load image processor
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    image_processor = ConditionalDetrImageProcessor(format=format)

    # prepare image
    img = prepare_img()
    encoding = image_processor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info(f"Converting model {model_name}...")

    # load original model from torch hub
    conditional_detr = torch.hub.load("DeppMeng/ConditionalDETR", model_name, pretrained=True).eval()
    state_dict = conditional_detr.state_dict()
    # rename keys
    for src, dest in rename_keys:
        if is_panoptic:
            src = "conditional_detr." + src
        rename_key(state_dict, src, dest)
    state_dict = rename_backbone_keys(state_dict)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "conditional_detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("conditional_detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
                val = state_dict.pop(key)
                state_dict["conditional_detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["conditional_detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = ConditionalDetrForSegmentation(config) if is_panoptic else ConditionalDetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()
    model.push_to_hub(repo_id=model_name, organization="DepuMeng", commit_message="Add model")
    # verify our conversion
    original_outputs = conditional_detr(pixel_values)
    outputs = model(pixel_values)
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)

    # Save model and image processor
    logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name",
        default="conditional_detr_resnet50",
        type=str,
        help="Name of the CONDITIONAL_DETR model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
    )
    args = parser.parse_args()
    convert_conditional_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
@@ -1,57 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvBERT checkpoint."""

import argparse

from transformers import ConvBertConfig, ConvBertModel, TFConvBertModel, load_tf_weights_in_convbert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_orig_tf1_checkpoint_to_pytorch(tf_checkpoint_path, convbert_config_file, pytorch_dump_path):
    conf = ConvBertConfig.from_json_file(convbert_config_file)
    model = ConvBertModel(conf)

    model = load_tf_weights_in_convbert(model, conf, tf_checkpoint_path)
    model.save_pretrained(pytorch_dump_path)

    tf_model = TFConvBertModel.from_pretrained(pytorch_dump_path, from_pt=True)
    tf_model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--convbert_config_file",
        default=None,
        type=str,
        required=True,
        help=(
            "The config json file corresponding to the pre-trained ConvBERT model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_orig_tf1_checkpoint_to_pytorch(args.tf_checkpoint_path, args.convbert_config_file, args.pytorch_dump_path)
@@ -1,242 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvNext checkpoints from the original repository.

URL: https://github.com/facebookresearch/ConvNeXt"""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_convnext_config(checkpoint_url):
    config = ConvNextConfig()

    if "tiny" in checkpoint_url:
        depths = [3, 3, 9, 3]
        hidden_sizes = [96, 192, 384, 768]
    if "small" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [96, 192, 384, 768]
    if "base" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [128, 256, 512, 1024]
    if "large" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [192, 384, 768, 1536]
    if "xlarge" in checkpoint_url:
        depths = [3, 3, 27, 3]
        hidden_sizes = [256, 512, 1024, 2048]

    if "1k" in checkpoint_url:
        num_labels = 1000
        filename = "imagenet-1k-id2label.json"
        expected_shape = (1, 1000)
    else:
        num_labels = 21841
        filename = "imagenet-22k-id2label.json"
        expected_shape = (1, 21841)

    repo_id = "huggingface/label-files"
    config.num_labels = num_labels
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    if "1k" not in checkpoint_url:
        # this dataset contains 21843 labels but the model only has 21841
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
        del id2label[9205]
        del id2label[15027]
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    config.hidden_sizes = hidden_sizes
    config.depths = depths

    return config, expected_shape


def rename_key(name):
    if "downsample_layers.0.0" in name:
        name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
    if "downsample_layers.0.1" in name:
        name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on
    if "downsample_layers.1.0" in name:
        name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
    if "downsample_layers.1.1" in name:
        name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
    if "downsample_layers.2.0" in name:
        name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
    if "downsample_layers.2.1" in name:
        name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
    if "downsample_layers.3.0" in name:
        name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
    if "downsample_layers.3.1" in name:
        name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
    if "stages" in name and "downsampling_layer" not in name:
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
        name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
    if "stages" in name:
        name = name.replace("stages", "encoder.stages")
    if "norm" in name:
        name = name.replace("norm", "layernorm")
    if "gamma" in name:
        name = name.replace("gamma", "layer_scale_parameter")
    if "head" in name:
        name = name.replace("head", "classifier")

    return name


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ConvNext structure.
    """

    # define ConvNext configuration based on URL
    config, expected_shape = get_convnext_config(checkpoint_url)
    # load original state_dict from URL
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
    # rename keys
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val
    # add prefix to all keys except classifier head
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        if not key.startswith("classifier"):
            key = "convnext." + key
        state_dict[key] = val

    # load HuggingFace model
    model = ConvNextForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image, prepared by ConvNextImageProcessor
    size = 224 if "224" in checkpoint_url else 384
    image_processor = ConvNextImageProcessor(size=size)
    pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values

    logits = model(pixel_values).logits

    # note: the logits below were obtained without center cropping
    if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth":
        expected_logits = torch.tensor([-0.1210, -0.6605, 0.1918])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth":
        expected_logits = torch.tensor([-0.4473, -0.1847, -0.6365])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth":
        expected_logits = torch.tensor([0.4525, 0.7539, 0.0308])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth":
        expected_logits = torch.tensor([0.3561, 0.6350, -0.0384])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth":
        expected_logits = torch.tensor([0.4174, -0.0989, 0.1489])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth":
        expected_logits = torch.tensor([0.2513, -0.1349, -0.1613])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth":
        expected_logits = torch.tensor([1.2980, 0.3631, -0.1198])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth":
        expected_logits = torch.tensor([1.2963, 0.1227, 0.1723])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth":
        expected_logits = torch.tensor([1.7956, 0.8390, 0.2820])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth":
        expected_logits = torch.tensor([-0.2822, -0.0502, -0.0878])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth":
        expected_logits = torch.tensor([-0.5672, -0.0730, -0.4348])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth":
        expected_logits = torch.tensor([0.2681, 0.2365, 0.6246])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth":
        expected_logits = torch.tensor([-0.2642, 0.3931, 0.5116])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth":
        expected_logits = torch.tensor([-0.6677, -0.1873, -0.8379])
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth":
        expected_logits = torch.tensor([-0.7749, -0.2967, -0.6444])
    else:
        raise ValueError(f"Unknown URL: {checkpoint_url}")

    assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3)
    assert logits.shape == expected_shape

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)

    print("Pushing model to the hub...")
    model_name = "convnext"
    if "tiny" in checkpoint_url:
        model_name += "-tiny"
    elif "small" in checkpoint_url:
        model_name += "-small"
    elif "base" in checkpoint_url:
        model_name += "-base"
    elif "xlarge" in checkpoint_url:
        model_name += "-xlarge"
    elif "large" in checkpoint_url:
        model_name += "-large"
    if "224" in checkpoint_url:
        model_name += "-224"
    elif "384" in checkpoint_url:
        model_name += "-384"
    if "22k" in checkpoint_url and "1k" not in checkpoint_url:
        model_name += "-22k"
    if "22k" in checkpoint_url and "1k" in checkpoint_url:
        model_name += "-22k-1k"

    model.push_to_hub(
        repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
        organization="nielsr",
        commit_message="Add model",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
        type=str,
        help="URL of the original ConvNeXT checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model directory.",
    )

    args = parser.parse_args()
    convert_convnext_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
@ -1,286 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert ConvNeXTV2 checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/facebookresearch/ConvNeXt"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification
 | 
			
		||||
from transformers.image_utils import PILImageResampling
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_convnextv2_config(checkpoint_url):
 | 
			
		||||
    config = ConvNextV2Config()
 | 
			
		||||
 | 
			
		||||
    if "atto" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [40, 80, 160, 320]
 | 
			
		||||
    if "femto" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [48, 96, 192, 384]
 | 
			
		||||
    if "pico" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 6, 2]
 | 
			
		||||
        hidden_sizes = [64, 128, 256, 512]
 | 
			
		||||
    if "nano" in checkpoint_url:
 | 
			
		||||
        depths = [2, 2, 8, 2]
 | 
			
		||||
        hidden_sizes = [80, 160, 320, 640]
 | 
			
		||||
    if "tiny" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 9, 3]
 | 
			
		||||
        hidden_sizes = [96, 192, 384, 768]
 | 
			
		||||
    if "base" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [128, 256, 512, 1024]
 | 
			
		||||
    if "large" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [192, 384, 768, 1536]
 | 
			
		||||
    if "huge" in checkpoint_url:
 | 
			
		||||
        depths = [3, 3, 27, 3]
 | 
			
		||||
        hidden_sizes = [352, 704, 1408, 2816]
 | 
			
		||||
 | 
			
		||||
    num_labels = 1000
 | 
			
		||||
    filename = "imagenet-1k-id2label.json"
 | 
			
		||||
    expected_shape = (1, 1000)
 | 
			
		||||
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    config.num_labels = num_labels
 | 
			
		||||
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    config.hidden_sizes = hidden_sizes
 | 
			
		||||
    config.depths = depths
 | 
			
		||||
 | 
			
		||||
    return config, expected_shape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(name):
 | 
			
		||||
    if "downsample_layers.0.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
 | 
			
		||||
    if "downsample_layers.0.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.0.1", "embeddings.norm")  # we rename to layernorm later on
 | 
			
		||||
    if "downsample_layers.1.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.1.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.2.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.2.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
 | 
			
		||||
    if "downsample_layers.3.0" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
 | 
			
		||||
    if "downsample_layers.3.1" in name:
 | 
			
		||||
        name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
 | 
			
		||||
    if "stages" in name and "downsampling_layer" not in name:
 | 
			
		||||
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
 | 
			
		||||
        name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
 | 
			
		||||
    if "gamma" in name:
 | 
			
		||||
        name = name.replace("gamma", "weight")
 | 
			
		||||
    if "beta" in name:
 | 
			
		||||
        name = name.replace("beta", "bias")
 | 
			
		||||
    if "stages" in name:
 | 
			
		||||
        name = name.replace("stages", "encoder.stages")
 | 
			
		||||
    if "norm" in name:
 | 
			
		||||
        name = name.replace("norm", "layernorm")
 | 
			
		||||
    if "head" in name:
 | 
			
		||||
        name = name.replace("head", "classifier")
 | 
			
		||||
 | 
			
		||||
    return name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_preprocessor(checkpoint_url):
 | 
			
		||||
    if "224" in checkpoint_url:
 | 
			
		||||
        size = 224
 | 
			
		||||
        crop_pct = 224 / 256
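        # note (assumption about ConvNextImageProcessor behaviour): with crop_pct set, the shorter edge
        # is first resized to size / crop_pct (256 here) and the image is then center-cropped to `size`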
 | 
			
		||||
    elif "384" in checkpoint_url:
 | 
			
		||||
        size = 384
 | 
			
		||||
        crop_pct = None
 | 
			
		||||
    else:
 | 
			
		||||
        size = 512
 | 
			
		||||
        crop_pct = None
 | 
			
		||||
 | 
			
		||||
    return ConvNextImageProcessor(
 | 
			
		||||
        size=size,
 | 
			
		||||
        crop_pct=crop_pct,
 | 
			
		||||
        image_mean=[0.485, 0.456, 0.406],
 | 
			
		||||
        image_std=[0.229, 0.224, 0.225],
 | 
			
		||||
        resample=PILImageResampling.BICUBIC,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our ConvNeXTV2 structure.
 | 
			
		||||
    """
 | 
			
		||||
    print("Downloading original model from checkpoint...")
 | 
			
		||||
    # define ConvNeXTV2 configuration based on URL
 | 
			
		||||
    config, expected_shape = get_convnextv2_config(checkpoint_url)
 | 
			
		||||
    # load original state_dict from URL
 | 
			
		||||
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
 | 
			
		||||
 | 
			
		||||
    print("Converting model parameters...")
 | 
			
		||||
    # rename keys
 | 
			
		||||
    for key in state_dict.copy().keys():
 | 
			
		||||
        val = state_dict.pop(key)
 | 
			
		||||
        state_dict[rename_key(key)] = val
 | 
			
		||||
    # add prefix to all keys except the classifier head
 | 
			
		||||
    for key in state_dict.copy().keys():
 | 
			
		||||
        val = state_dict.pop(key)
 | 
			
		||||
        if not key.startswith("classifier"):
 | 
			
		||||
            key = "convnextv2." + key
 | 
			
		||||
        state_dict[key] = val
 | 
			
		||||
 | 
			
		||||
    # load HuggingFace model
 | 
			
		||||
    model = ConvNextV2ForImageClassification(config)
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    # Check outputs on an image, prepared by ConvNextImageProcessor
 | 
			
		||||
    preprocessor = convert_preprocessor(checkpoint_url)
 | 
			
		||||
    inputs = preprocessor(images=prepare_img(), return_tensors="pt")
 | 
			
		||||
    logits = model(**inputs).logits
 | 
			
		||||
 | 
			
		||||
    # note: the logits below were obtained without center cropping
 | 
			
		||||
    if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309])
 | 
			
		||||
    elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt":
 | 
			
		||||
        expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826])
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Unknown URL: {checkpoint_url}")
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3)
 | 
			
		||||
    assert logits.shape == expected_shape
 | 
			
		||||
    print("Model outputs match the original results!")
 | 
			
		||||
 | 
			
		||||
    if save_model:
 | 
			
		||||
        print("Saving model to local...")
 | 
			
		||||
        # Create folder to save model
 | 
			
		||||
        if not os.path.isdir(pytorch_dump_folder_path):
 | 
			
		||||
            os.mkdir(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
        model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
        preprocessor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    model_name = "convnextv2"
 | 
			
		||||
    if "atto" in checkpoint_url:
 | 
			
		||||
        model_name += "-atto"
 | 
			
		||||
    if "femto" in checkpoint_url:
 | 
			
		||||
        model_name += "-femto"
 | 
			
		||||
    if "pico" in checkpoint_url:
 | 
			
		||||
        model_name += "-pico"
 | 
			
		||||
    if "nano" in checkpoint_url:
 | 
			
		||||
        model_name += "-nano"
 | 
			
		||||
    elif "tiny" in checkpoint_url:
 | 
			
		||||
        model_name += "-tiny"
 | 
			
		||||
    elif "base" in checkpoint_url:
 | 
			
		||||
        model_name += "-base"
 | 
			
		||||
    elif "large" in checkpoint_url:
 | 
			
		||||
        model_name += "-large"
 | 
			
		||||
    elif "huge" in checkpoint_url:
 | 
			
		||||
        model_name += "-huge"
 | 
			
		||||
    if "22k" in checkpoint_url and "1k" not in checkpoint_url:
 | 
			
		||||
        model_name += "-22k"
 | 
			
		||||
    elif "22k" in checkpoint_url and "1k" in checkpoint_url:
 | 
			
		||||
        model_name += "-22k-1k"
 | 
			
		||||
    elif "1k" in checkpoint_url:
 | 
			
		||||
        model_name += "-1k"
 | 
			
		||||
    if "224" in checkpoint_url:
 | 
			
		||||
        model_name += "-224"
 | 
			
		||||
    elif "384" in checkpoint_url:
 | 
			
		||||
        model_name += "-384"
 | 
			
		||||
    elif "512" in checkpoint_url:
 | 
			
		||||
        model_name += "-512"
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        print(f"Pushing {model_name} to the hub...")
 | 
			
		||||
        model.push_to_hub(model_name)
 | 
			
		||||
        preprocessor.push_to_hub(model_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_url",
 | 
			
		||||
        default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default="model",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the output PyTorch model directory.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
 | 
			
		||||
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_convnextv2_checkpoint(
 | 
			
		||||
        args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
 | 
			
		||||
    )
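# Example invocation (illustrative; the script filename is an assumption, the flags are the ones defined above):
# python convert_convnextv2_to_pytorch.py \
#     --checkpoint_url https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt \
#     --pytorch_dump_folder_path model --save_model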
 | 
			
		||||
@ -1,362 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert CvT checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/microsoft/CvT"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
 | 
			
		||||
from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def embeddings(idx):
 | 
			
		||||
    """
 | 
			
		||||
    The function helps in renaming embedding layer weights.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        idx: stage number in original model
 | 
			
		||||
    """
 | 
			
		||||
    embed = []
 | 
			
		||||
    embed.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
 | 
			
		||||
            f"stage{idx}.patch_embed.proj.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    embed.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
 | 
			
		||||
            f"stage{idx}.patch_embed.proj.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    embed.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
 | 
			
		||||
            f"stage{idx}.patch_embed.norm.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    embed.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
 | 
			
		||||
            f"stage{idx}.patch_embed.norm.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    return embed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def attention(idx, cnt):
 | 
			
		||||
    """
 | 
			
		||||
    The function helps in renaming attention block layers weights.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        idx: stage number in original model
 | 
			
		||||
        cnt: count of blocks in each stage
 | 
			
		||||
    """
 | 
			
		||||
    attention_weights = []
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj.weight",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (
 | 
			
		||||
            f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
 | 
			
		||||
            f"stage{idx}.blocks.{cnt}.attn.proj.bias",
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
 | 
			
		||||
    )
 | 
			
		||||
    attention_weights.append(
 | 
			
		||||
        (f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
 | 
			
		||||
    )
 | 
			
		||||
    return attention_weights
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cls_token(idx):
 | 
			
		||||
    """
 | 
			
		||||
    Function helps in renaming cls_token weights
 | 
			
		||||
    """
 | 
			
		||||
    token = []
 | 
			
		||||
    token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
 | 
			
		||||
    return token
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def final():
 | 
			
		||||
    """
 | 
			
		||||
    Function helps in renaming final classification layer
 | 
			
		||||
    """
 | 
			
		||||
    head = []
 | 
			
		||||
    head.append(("layernorm.weight", "norm.weight"))
 | 
			
		||||
    head.append(("layernorm.bias", "norm.bias"))
 | 
			
		||||
    head.append(("classifier.weight", "head.weight"))
 | 
			
		||||
    head.append(("classifier.bias", "head.bias"))
 | 
			
		||||
    return head
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
 | 
			
		||||
    """
 | 
			
		||||
    Function to convert the original Microsoft CvT checkpoint to a HuggingFace checkpoint
 | 
			
		||||
    """
 | 
			
		||||
    img_labels_file = "imagenet-1k-id2label.json"
 | 
			
		||||
    num_labels = 1000
 | 
			
		||||
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
 | 
			
		||||
    id2label = json.loads(Path(hf_hub_download(repo_id, img_labels_file, repo_type="dataset")).read_text())
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
 | 
			
		||||
 | 
			
		||||
    # For depth size 13 (13 = 1+2+10)
 | 
			
		||||
    if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
 | 
			
		||||
        config.depth = [1, 2, 10]
 | 
			
		||||
 | 
			
		||||
    # For depth size 21 (21 = 1+4+16)
 | 
			
		||||
    elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
 | 
			
		||||
        config.depth = [1, 4, 16]
 | 
			
		||||
 | 
			
		||||
    # For wide CvT (similar to wide-resnet) depth size 24 (w24 = 2 + 2 + 20)
 | 
			
		||||
    else:
 | 
			
		||||
        config.depth = [2, 2, 20]
 | 
			
		||||
        config.num_heads = [3, 12, 16]
 | 
			
		||||
        config.embed_dim = [192, 768, 1024]
 | 
			
		||||
 | 
			
		||||
    model = CvtForImageClassification(config)
 | 
			
		||||
    image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
 | 
			
		||||
    image_processor.size["shortest_edge"] = image_size
 | 
			
		||||
    original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    huggingface_weights = OrderedDict()
 | 
			
		||||
    list_of_state_dict = []
 | 
			
		||||
 | 
			
		||||
    for idx in range(len(config.depth)):
 | 
			
		||||
        if config.cls_token[idx]:
 | 
			
		||||
            list_of_state_dict = list_of_state_dict + cls_token(idx)
 | 
			
		||||
        list_of_state_dict = list_of_state_dict + embeddings(idx)
 | 
			
		||||
        for cnt in range(config.depth[idx]):
 | 
			
		||||
            list_of_state_dict = list_of_state_dict + attention(idx, cnt)
 | 
			
		||||
 | 
			
		||||
    list_of_state_dict = list_of_state_dict + final()
 | 
			
		||||
    for gg in list_of_state_dict:
 | 
			
		||||
        print(gg)
 | 
			
		||||
    for i in range(len(list_of_state_dict)):
 | 
			
		||||
        huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
 | 
			
		||||
 | 
			
		||||
    model.load_state_dict(huggingface_weights)
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder)
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--cvt_model",
 | 
			
		||||
        default="cvt-w24",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Name of the cvt model you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--image_size",
 | 
			
		||||
        default=384,
 | 
			
		||||
        type=int,
 | 
			
		||||
        help="Input Image Size",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--cvt_file_name",
 | 
			
		||||
        default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Input Image Size",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)
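# Example invocation (illustrative; the script filename is an assumption, the weights come from the zoo link above):
# python convert_cvt_original_pytorch_checkpoint_to_pytorch.py \
#     --cvt_model cvt-w24 --image_size 384 \
#     --cvt_file_name cvtmodels/CvT-w24-384x384-IN-22k.pth \
#     --pytorch_dump_folder_path ./cvt-w24-384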
 | 
			
		||||
@ -1,233 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2024 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert DAB-DETR checkpoints."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import gc
 | 
			
		||||
import json
 | 
			
		||||
import re
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
 | 
			
		||||
from transformers import ConditionalDetrImageProcessor, DabDetrConfig, DabDetrForObjectDetection
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
 | 
			
		||||
    # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
 | 
			
		||||
    # for dab-DETR, also convert reference point head and query scale MLP
 | 
			
		||||
    r"input_proj\.(bias|weight)": r"input_projection.\1",
 | 
			
		||||
    r"refpoint_embed\.weight": r"query_refpoint_embeddings.weight",
 | 
			
		||||
    r"class_embed\.(bias|weight)": r"class_embed.\1",
 | 
			
		||||
    # negative lookbehind because of the overlap
 | 
			
		||||
    r"(?<!transformer\.decoder\.)bbox_embed\.layers\.(\d+)\.(bias|weight)": r"bbox_predictor.layers.\1.\2",
 | 
			
		||||
    r"transformer\.encoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"encoder.query_scale.layers.\1.\2",
 | 
			
		||||
    r"transformer\.decoder\.bbox_embed\.layers\.(\d+)\.(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
 | 
			
		||||
    r"transformer\.decoder\.norm\.(bias|weight)": r"decoder.layernorm.\1",
 | 
			
		||||
    r"transformer\.decoder\.ref_point_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_point_head.layers.\1.\2",
 | 
			
		||||
    r"transformer\.decoder\.ref_anchor_head\.layers\.(\d+)\.(bias|weight)": r"decoder.ref_anchor_head.layers.\1.\2",
 | 
			
		||||
    r"transformer\.decoder\.query_scale\.layers\.(\d+)\.(bias|weight)": r"decoder.query_scale.layers.\1.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.0\.ca_qpos_proj\.(bias|weight)": r"decoder.layers.0.cross_attn.cross_attn_query_pos_proj.\1",
 | 
			
		||||
    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + activation function
 | 
			
		||||
    # output projection
 | 
			
		||||
    r"transformer\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"encoder.layers.\1.self_attn.out_proj.\2",
 | 
			
		||||
    # FFN layers
 | 
			
		||||
    r"transformer\.encoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"encoder.layers.\1.fc\2.\3",
 | 
			
		||||
    # normalization layers
 | 
			
		||||
    # nm1
 | 
			
		||||
    r"transformer\.encoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"encoder.layers.\1.self_attn_layer_norm.\2",
 | 
			
		||||
    # nm2
 | 
			
		||||
    r"transformer\.encoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"encoder.layers.\1.final_layer_norm.\2",
 | 
			
		||||
    # activation function weight
 | 
			
		||||
    r"transformer\.encoder\.layers\.(\d+)\.activation\.weight": r"encoder.layers.\1.activation_fn.weight",
 | 
			
		||||
    #########################################################################################################################################
 | 
			
		||||
    # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + activation function weight
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn.output_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.cross_attn\.out_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn.output_proj.\2",
 | 
			
		||||
    # FFNs
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.linear(\d)\.(bias|weight)": r"decoder.layers.\1.mlp.fc\2.\3",
 | 
			
		||||
    # nm1
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.norm1\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_layer_norm.\2",
 | 
			
		||||
    # nm2
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.norm2\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_layer_norm.\2",
 | 
			
		||||
    # nm3
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.norm3\.(bias|weight)": r"decoder.layers.\1.mlp.final_layer_norm.\2",
 | 
			
		||||
    # activation function weight
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.activation\.weight": r"decoder.layers.\1.mlp.activation_fn.weight",
 | 
			
		||||
    # q, k, v projections and biases in self-attention in decoder
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.sa_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_content_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.sa_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_content_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.sa_qpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_query_pos_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.sa_kpos_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_key_pos_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.sa_v_proj\.(bias|weight)": r"decoder.layers.\1.self_attn.self_attn_value_proj.\2",
 | 
			
		||||
    # q, k, v projections in cross-attention in decoder
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.ca_qcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_content_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.ca_kcontent_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_content_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.ca_kpos_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_key_pos_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.ca_v_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_value_proj.\2",
 | 
			
		||||
    r"transformer\.decoder\.layers\.(\d+)\.ca_qpos_sine_proj\.(bias|weight)": r"decoder.layers.\1.cross_attn.cross_attn_query_pos_sine_proj.\2",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Copied from transformers.models.mllama.convert_mllama_weights_to_hf.convert_old_keys_to_new_keys
 | 
			
		||||
def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
 | 
			
		||||
    """
 | 
			
		||||
    This function should be applied only once, on the concatenated keys to efficiently rename using
 | 
			
		||||
    the key mappings.
 | 
			
		||||
    """
 | 
			
		||||
    output_dict = {}
 | 
			
		||||
    if state_dict_keys is not None:
 | 
			
		||||
        old_text = "\n".join(state_dict_keys)
 | 
			
		||||
        new_text = old_text
 | 
			
		||||
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
 | 
			
		||||
            if replacement is None:
 | 
			
		||||
                new_text = re.sub(pattern, "", new_text)  # an empty line
 | 
			
		||||
                continue
 | 
			
		||||
            new_text = re.sub(pattern, replacement, new_text)
 | 
			
		||||
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
 | 
			
		||||
    return output_dict
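# e.g. "transformer.encoder.layers.0.norm1.weight" is mapped to "encoder.layers.0.self_attn_layer_norm.weight"
# (illustrative example derived from the regex table above)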
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub):
 | 
			
		||||
    logger.info("Converting image processor...")
 | 
			
		||||
    format = "coco_detection"
 | 
			
		||||
    image_processor = ConditionalDetrImageProcessor(format=format)
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        image_processor.push_to_hub(repo_id=model_name, commit_message="Add new image processor")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
 | 
			
		||||
    # load modified config. For "dc5" checkpoints, dilation has to be set when the config is created, because the backbone kwargs are fixed once the default config has been loaded.
 | 
			
		||||
    if "dc5" in model_name:
 | 
			
		||||
        config = DabDetrConfig(dilation=True)
 | 
			
		||||
    else:
 | 
			
		||||
        # load default config
 | 
			
		||||
        config = DabDetrConfig()
 | 
			
		||||
    # set other attributes
 | 
			
		||||
    if "dab-detr-resnet-50-dc5" == model_name:
 | 
			
		||||
        config.temperature_height = 10
 | 
			
		||||
        config.temperature_width = 10
 | 
			
		||||
    if "fixxy" in model_name:
 | 
			
		||||
        config.random_refpoints_xy = True
 | 
			
		||||
    if "pat3" in model_name:
 | 
			
		||||
        config.num_patterns = 3
 | 
			
		||||
        # only when the number of patterns (num_patterns parameter in config) is more than 0, like r50-pat3 or r50dc5-pat3
 | 
			
		||||
        ORIGINAL_TO_CONVERTED_KEY_MAPPING.update({r"transformer.patterns.weight": r"patterns.weight"})
 | 
			
		||||
 | 
			
		||||
    config.num_labels = 91
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    filename = "coco-detection-id2label.json"
 | 
			
		||||
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
    # load original model from local path
 | 
			
		||||
    loaded = torch.load(pretrained_model_weights_path, map_location=torch.device("cpu"))["model"]
 | 
			
		||||
    # Rename the original model state dictionary so it is HF compatible
 | 
			
		||||
    all_keys = list(loaded.keys())
 | 
			
		||||
    new_keys = convert_old_keys_to_new_keys(all_keys)
 | 
			
		||||
    state_dict = {}
 | 
			
		||||
    for key in all_keys:
 | 
			
		||||
        if "backbone.0.body" in key:
 | 
			
		||||
            new_key = key.replace("backbone.0.body", "backbone.conv_encoder.model._backbone")
 | 
			
		||||
            state_dict[new_key] = loaded[key]
 | 
			
		||||
        # Q, K, V encoder values mapping
 | 
			
		||||
        elif re.search("self_attn.in_proj_(weight|bias)", key):
 | 
			
		||||
            # Dynamically find the layer number
 | 
			
		||||
            pattern = r"layers\.(\d+)\.self_attn\.in_proj_(weight|bias)"
 | 
			
		||||
            match = re.search(pattern, key)
 | 
			
		||||
            if match:
 | 
			
		||||
                layer_num = match.group(1)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Pattern not found in key: {key}")
 | 
			
		||||
 | 
			
		||||
            in_proj_value = loaded.pop(key)
 | 
			
		||||
            if "weight" in key:
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.weight"] = in_proj_value[:256, :]
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.weight"] = in_proj_value[256:512, :]
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.weight"] = in_proj_value[-256:, :]
 | 
			
		||||
            elif "bias" in key:
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.q_proj.bias"] = in_proj_value[:256]
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.k_proj.bias"] = in_proj_value[256:512]
 | 
			
		||||
                state_dict[f"encoder.layers.{layer_num}.self_attn.v_proj.bias"] = in_proj_value[-256:]
 | 
			
		||||
        else:
 | 
			
		||||
            new_key = new_keys[key]
 | 
			
		||||
            state_dict[new_key] = loaded[key]
 | 
			
		||||
 | 
			
		||||
    del loaded
 | 
			
		||||
    gc.collect()
 | 
			
		||||
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
 | 
			
		||||
    prefix = "model."
 | 
			
		||||
    for key in state_dict.copy().keys():
 | 
			
		||||
        if not key.startswith("class_embed") and not key.startswith("bbox_predictor"):
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            state_dict[prefix + key] = val
 | 
			
		||||
    # finally, create HuggingFace model and load state dict
 | 
			
		||||
    model = DabDetrForObjectDetection(config)
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
    logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...")
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        model.push_to_hub(repo_id=model_name, commit_message="Add new model")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_dab_detr_checkpoint(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub):
 | 
			
		||||
    logger.info("Converting image processor...")
 | 
			
		||||
    write_image_processor(model_name, pytorch_dump_folder_path, push_to_hub)
 | 
			
		||||
 | 
			
		||||
    logger.info(f"Converting model {model_name}...")
 | 
			
		||||
    write_model(model_name, pretrained_model_weights_path, pytorch_dump_folder_path, push_to_hub)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_name",
 | 
			
		||||
        default="dab-detr-resnet-50",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Name of the DAB_DETR model you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pretrained_model_weights_path",
 | 
			
		||||
        default="modelzoo/R50/checkpoint.pth",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="The path of the original model weights like: modelzoo/checkpoint.pth",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path", default="DAB_DETR", type=str, help="Path to the folder to output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub",
 | 
			
		||||
        default=True,
 | 
			
		||||
        type=bool,
 | 
			
		||||
        help="Whether to upload the converted weights and image processor config to the HuggingFace model profile. Default is set to false.",
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_dab_detr_checkpoint(
 | 
			
		||||
        args.model_name, args.pretrained_model_weights_path, args.pytorch_dump_folder_path, args.push_to_hub
 | 
			
		||||
    )
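# Example invocation (illustrative; the script filename is an assumption). Note that because --push_to_hub is
# parsed with type=bool, any non-empty string value evaluates to True:
# python convert_dab_detr_to_hf.py --model_name dab-detr-resnet-50 \
#     --pretrained_model_weights_path modelzoo/R50/checkpoint.pth \
#     --pytorch_dump_folder_path DAB_DETR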
 | 
			
		||||
@ -1,261 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import argparse
 | 
			
		||||
import fnmatch
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    DacConfig,
 | 
			
		||||
    DacFeatureExtractor,
 | 
			
		||||
    DacModel,
 | 
			
		||||
    logging,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# checkpoints downloaded using:
 | 
			
		||||
# pip install descript-audio-codec
 | 
			
		||||
# python3 -m dac download # downloads the default 44kHz variant
 | 
			
		||||
# python3 -m dac download --model_type 44khz # downloads the 44kHz variant
 | 
			
		||||
# python3 -m dac download --model_type 24khz # downloads the 24kHz variant
 | 
			
		||||
# python3 -m dac download --model_type 16khz # downloads the 16kHz variant
 | 
			
		||||
# More information: https://github.com/descriptinc/descript-audio-codec/tree/main
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger("transformers.models.dac")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def match_pattern(string, pattern):
 | 
			
		||||
    # Split the pattern into parts
 | 
			
		||||
    pattern_parts = pattern.split(".")
 | 
			
		||||
    string_parts = string.split(".")
 | 
			
		||||
 | 
			
		||||
    pattern_block_count = string_block_count = 0
 | 
			
		||||
 | 
			
		||||
    for part in pattern_parts:
 | 
			
		||||
        if part.startswith("block"):
 | 
			
		||||
            pattern_block_count += 1
 | 
			
		||||
 | 
			
		||||
    for part in string_parts:
 | 
			
		||||
        if part.startswith("block"):
 | 
			
		||||
            string_block_count += 1
 | 
			
		||||
 | 
			
		||||
    return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TOP_LEVEL_KEYS = []
 | 
			
		||||
IGNORE_KEYS = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MAPPING_ENCODER = {
 | 
			
		||||
    "encoder.block.0": ["encoder.conv1"],
 | 
			
		||||
    "encoder.block.5": ["encoder.snake1"],
 | 
			
		||||
    "encoder.block.6": ["encoder.conv2"],
 | 
			
		||||
    "encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
 | 
			
		||||
    "encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
 | 
			
		||||
    "encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
 | 
			
		||||
    "encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
 | 
			
		||||
    "encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
 | 
			
		||||
    "encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MAPPING_QUANTIZER = {
 | 
			
		||||
    "quantizer.quantizers.*": ["quantizer.quantizers.*"],
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MAPPING_DECODER = {
 | 
			
		||||
    "decoder.model.0": ["decoder.conv1"],
 | 
			
		||||
    "decoder.model.5": ["decoder.snake1"],
 | 
			
		||||
    "decoder.model.6": ["decoder.conv2"],
 | 
			
		||||
    "decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
 | 
			
		||||
    "decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
 | 
			
		||||
    "decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
 | 
			
		||||
    "decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
 | 
			
		||||
    "decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
 | 
			
		||||
    "decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MAPPING = {
 | 
			
		||||
    **MAPPING_ENCODER,
 | 
			
		||||
    **MAPPING_QUANTIZER,
 | 
			
		||||
    **MAPPING_DECODER,
 | 
			
		||||
}
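# a single-element value above maps to a fixed module name (quantizer keys keep their original prefix);
# 2- and 3-element values are templates that recursively_load_weights below fills in with the (shifted)
# block indices found in the original key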
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
 | 
			
		||||
    for attribute in key.split("."):
 | 
			
		||||
        hf_pointer = getattr(hf_pointer, attribute)
 | 
			
		||||
 | 
			
		||||
    if weight_type is not None:
 | 
			
		||||
        hf_shape = getattr(hf_pointer, weight_type).shape
 | 
			
		||||
    else:
 | 
			
		||||
        hf_shape = hf_pointer.shape
 | 
			
		||||
 | 
			
		||||
    if hf_shape != value.shape:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
 | 
			
		||||
            f" {value.shape} for {full_name}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if weight_type == "weight":
 | 
			
		||||
        hf_pointer.weight.data = value
 | 
			
		||||
    elif weight_type == "weight_g":
 | 
			
		||||
        hf_pointer.weight_g.data = value
 | 
			
		||||
    elif weight_type == "weight_v":
 | 
			
		||||
        hf_pointer.weight_v.data = value
 | 
			
		||||
    elif weight_type == "bias":
 | 
			
		||||
        hf_pointer.bias.data = value
 | 
			
		||||
    elif weight_type == "alpha":
 | 
			
		||||
        hf_pointer.alpha.data = value
 | 
			
		||||
    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def should_ignore(name, ignore_keys):
 | 
			
		||||
    for key in ignore_keys:
 | 
			
		||||
        if key.endswith(".*"):
 | 
			
		||||
            if name.startswith(key[:-1]):
 | 
			
		||||
                return True
 | 
			
		||||
        elif ".*." in key:
 | 
			
		||||
            prefix, suffix = key.split(".*.")
 | 
			
		||||
            if prefix in name and suffix in name:
 | 
			
		||||
                return True
 | 
			
		||||
        elif key in name:
 | 
			
		||||
            return True
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def recursively_load_weights(orig_dict, hf_model, model_name):
 | 
			
		||||
    unused_weights = []
 | 
			
		||||
 | 
			
		||||
    if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
 | 
			
		||||
        raise ValueError(f"Unsupported model: {model_name}")
 | 
			
		||||
 | 
			
		||||
    for name, value in orig_dict.items():
 | 
			
		||||
        is_used = False
 | 
			
		||||
        for key, mapped_key in MAPPING.items():
 | 
			
		||||
            regex = re.compile(key)
 | 
			
		||||
            if regex.search(name):
 | 
			
		||||
                if len(mapped_key) == 1:
 | 
			
		||||
                    if mapped_key[0][0] == "q":
 | 
			
		||||
                        mapped_key = ".".join(name.split(".")[:-1])
 | 
			
		||||
                    else:
 | 
			
		||||
                        mapped_key = mapped_key[0]
 | 
			
		||||
                elif len(mapped_key) == 3:
 | 
			
		||||
                    integers = re.findall(r"\b\d+\b", name)
 | 
			
		||||
                    if mapped_key[0][0] == "d":
 | 
			
		||||
                        mapped_key = "{}.{}.{}{}.{}".format(
 | 
			
		||||
                            mapped_key[0],
 | 
			
		||||
                            str(int(integers[0]) - 1),
 | 
			
		||||
                            mapped_key[1],
 | 
			
		||||
                            str(int(integers[1]) - 1),
 | 
			
		||||
                            mapped_key[2],
 | 
			
		||||
                        )
 | 
			
		||||
                    else:
 | 
			
		||||
                        mapped_key = "{}.{}.{}{}.{}".format(
 | 
			
		||||
                            mapped_key[0],
 | 
			
		||||
                            str(int(integers[0]) - 1),
 | 
			
		||||
                            mapped_key[1],
 | 
			
		||||
                            str(int(integers[1]) + 1),
 | 
			
		||||
                            mapped_key[2],
 | 
			
		||||
                        )
 | 
			
		||||
                elif len(mapped_key) == 2:
 | 
			
		||||
                    integers = re.findall(r"\b\d+\b", name)
 | 
			
		||||
                    mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])
 | 
			
		||||
 | 
			
		||||
                is_used = True
 | 
			
		||||
                if "weight_g" in name:
 | 
			
		||||
                    weight_type = "weight_g"
 | 
			
		||||
                elif "weight_v" in name:
 | 
			
		||||
                    weight_type = "weight_v"
 | 
			
		||||
                elif "bias" in name:
 | 
			
		||||
                    weight_type = "bias"
 | 
			
		||||
                elif "alpha" in name:
 | 
			
		||||
                    weight_type = "alpha"
 | 
			
		||||
                elif "weight" in name:
 | 
			
		||||
                    weight_type = "weight"
 | 
			
		||||
                set_recursively(hf_model, mapped_key, value, name, weight_type)
 | 
			
		||||
 | 
			
		||||
        if not is_used:
 | 
			
		||||
            unused_weights.append(name)
 | 
			
		||||
 | 
			
		||||
    print(list(set(unused_weights)))
 | 
			
		||||
 | 
			
		||||
    logger.warning(f"Unused weights: {unused_weights}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
def convert_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder_path,
    sample_rate=16000,
    repo_id=None,
):
    model_dict = torch.load(checkpoint_path, "cpu")

    config = DacConfig()

    metadata = model_dict["metadata"]["kwargs"]
    config.encoder_hidden_size = metadata["encoder_dim"]
    config.downsampling_ratios = metadata["encoder_rates"]
    config.codebook_size = metadata["codebook_size"]
    config.n_codebooks = metadata["n_codebooks"]
    config.codebook_dim = metadata["codebook_dim"]
    config.decoder_hidden_size = metadata["decoder_dim"]
    config.upsampling_ratios = metadata["decoder_rates"]
    config.quantizer_dropout = float(metadata["quantizer_dropout"])
    config.sampling_rate = sample_rate

    model = DacModel(config)
    feature_extractor = DacFeatureExtractor()
    feature_extractor.sampling_rate = sample_rate

    original_checkpoint = model_dict["state_dict"]

    model.apply_weight_norm()
    recursively_load_weights(original_checkpoint, model, model_name)
    model.remove_weight_norm()

    model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        feature_extractor.push_to_hub(repo_id)
        model.push_to_hub(repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="dac_44khz",
        type=str,
        help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
    )
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )
    parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor")
    args = parser.parse_args()

    convert_checkpoint(
        args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
    )
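

# Added usage sketch (not part of the original script): round-tripping audio through the converted codec.
# The checkpoint folder is a placeholder, and the output attribute name follows the transformers DAC
# implementation as I understand it; treat this as an illustrative sketch rather than a tested snippet.
def _example_dac_roundtrip():
    import torch
    from transformers import DacFeatureExtractor, DacModel

    feature_extractor = DacFeatureExtractor.from_pretrained("path/to/converted_dac")  # placeholder path
    model = DacModel.from_pretrained("path/to/converted_dac").eval()

    raw_audio = [0.0] * feature_extractor.sampling_rate  # placeholder: one second of silence
    inputs = feature_extractor(raw_audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
    with torch.no_grad():
        reconstructed = model(inputs["input_values"]).audio_values
    print(reconstructed.shape)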
@ -1,285 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""

import argparse
import os
from functools import reduce

import fairseq
import torch
from datasets import load_dataset

from transformers import Wav2Vec2Processor, logging
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig

# Copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
from transformers.models.data2vec.data2vec_audio import Data2VecAudioModel as Dummy  # noqa: F401
from transformers.models.data2vec.modeling_data2vec_audio import Data2VecAudioForCTC, Data2VecAudioModel


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    "models.0.layer_norm": "feature_projection.layer_norm",
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
}
TOP_LEVEL_KEYS = [
    "lm_head",
]


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    else:
        hf_pointer.data = value

    logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
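

# Added illustration (not from the original file): `set_recursively` walks a dotted key with getattr and then
# assigns into `.data`. The same idea on a self-contained toy module (all names below are invented for the example):
def _example_set_recursively_sketch():
    import torch
    from torch import nn

    toy = nn.Module()
    toy.encoder = nn.Module()
    toy.encoder.proj = nn.Linear(4, 4)

    pointer = toy
    for attribute in "encoder.proj".split("."):
        pointer = getattr(pointer, attribute)
    pointer.weight.data = torch.zeros(4, 4)  # equivalent to the weight_type == "weight" branch above

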
def recursively_load_weights(fairseq_model, hf_model, is_headless):
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    if not is_headless:
        feature_extractor = hf_model.data2vec_audio.feature_extractor
        pos_conv_embedding = hf_model.data2vec_audio.encoder.pos_conv_embed

    else:
        feature_extractor = hf_model.feature_extractor
        pos_conv_embedding = hf_model.encoder.pos_conv_embed

    for name, value in fairseq_dict.items():
        is_used = False
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
            )
            is_used = True
        elif "pos_conv" in name:
            load_pos_conv_layer(
                name,
                value,
                pos_conv_embedding,
                unused_weights,
            )
            is_used = True
        else:
            for key, mapped_key in MAPPING.items():
                if not is_headless:
                    mapped_key = "data2vec_audio." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
                if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
                    is_used = True
                    if "*" in mapped_key:
                        layer_index = name.split(key)[0].split(".")[-2]
                        mapped_key = mapped_key.replace("*", layer_index)
                    if "weight_g" in name:
                        weight_type = "weight_g"
                    elif "weight_v" in name:
                        weight_type = "weight_v"
                    elif "bias" in name:
                        weight_type = "bias"
                    elif "weight" in name:
                        # TODO: don't match quantizer.weight_proj
                        weight_type = "weight"
                    else:
                        weight_type = None
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
                continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


def access_by_string(module, path):
    names = path.split(".")
    return reduce(getattr, names, module)


def set_weights(full_name, module, fsq_value, hf_weight_path):
    hf_weight = access_by_string(module, hf_weight_path)
    hf_value = hf_weight.data

    if fsq_value.shape != hf_value.shape:
        raise ValueError(f"{full_name} has size {fsq_value.shape}, but {hf_value.shape} was found.")
    hf_weight.data = fsq_value
    logger.info(f"{full_name} was correctly initialized from {hf_weight_path}.")


def load_conv_layer(full_name, value, feature_extractor, unused_weights):
    name = full_name.split("conv_layers.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])

    weight_type = name.split(".")[-1]
    if type_id == 0:
        layer_type = "conv"
    elif type_id == 2:
        layer_type = "layer_norm"
    else:
        unused_weights.append(full_name)
        return

    set_weights(full_name, feature_extractor, value, f"conv_layers.{layer_id}.{layer_type}.{weight_type}")


def load_pos_conv_layer(full_name, value, pos_conv_embeddings, unused_weights):
    name = full_name.split("pos_conv.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])

    weight_type = name.split(".")[-1]
    if type_id != 0:
        unused_weights.append(full_name)
        return
    else:
        layer_type = "conv"

    set_weights(full_name, pos_conv_embeddings, value, f"layers.{layer_id}.{layer_type}.{weight_type}")


@torch.no_grad()
def convert_wav2vec2_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = Data2VecAudioConfig.from_pretrained(config_path)
    else:
        config = Data2VecAudioConfig()

    if not is_finetuned:
        # Modify final_proj layer name
        hf_wav2vec = Data2VecAudioModel(config)
        data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)

        state_dict = torch.load(checkpoint_path, weights_only=True)
        state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
        state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
        converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
        torch.save(state_dict, converted_ckpt)
    else:
        hf_wav2vec = Data2VecAudioForCTC(config)
        converted_ckpt = checkpoint_path

    def load_data2vec(path):
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([path])
        return model[0].eval()

    model = load_data2vec(converted_ckpt)

    recursively_load_weights(model, hf_wav2vec, not is_finetuned)

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")

    ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
    input_audio = [x["array"] for x in ds[:4]["audio"]]

    inputs = processor(input_audio, return_tensors="pt", padding=True)

    input_values = inputs.input_values
    attention_mask = inputs.attention_mask
    #    input_values = inputs.input_values[:, :-1]
    #    attention_mask = inputs.attention_mask[:, :-1]

    hf_wav2vec.eval()
    model.eval()
    if is_finetuned:
        their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
            "encoder_out"
        ].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["logits"]

        pred_ids = torch.argmax(our_output, dim=-1)
        output_string = processor.batch_decode(pred_ids)

        print(f"Expected Output: {ds[:4]['text']}, Pred: {output_string}")
    else:
        their_output = model(source=input_values, padding_mask=(1 - attention_mask), mask=False, features_only=True)[
            "layer_results"
        ][-1][0].transpose(0, 1)
        our_output = hf_wav2vec(input_values, attention_mask=attention_mask)["last_hidden_state"]

    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

    if is_finetuned:
        processor.save_pretrained(pytorch_dump_folder_path)
    else:
        processor.feature_extractor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
    )
    args = parser.parse_args()
    convert_wav2vec2_checkpoint(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
    )
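

# Added usage sketch (not part of the original script): transcribing audio with the converted CTC checkpoint.
# The folder path and the silent audio array are placeholders; the calls follow the standard
# Wav2Vec2Processor / Data2VecAudioForCTC interface.
def _example_data2vec_audio_inference():
    import torch
    from transformers import Data2VecAudioForCTC, Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("path/to/converted_model")  # placeholder path
    model = Data2VecAudioForCTC.from_pretrained("path/to/converted_model").eval()

    raw_audio = [0.0] * 16000  # placeholder: one second of 16 kHz audio
    inputs = processor(raw_audio, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
    print(processor.batch_decode(torch.argmax(logits, dim=-1)))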
@ -1,207 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert data2vec checkpoint."""

import argparse
import os
import pathlib

import fairseq
import torch
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers import (
    Data2VecTextConfig,
    Data2VecTextForMaskedLM,
    Data2VecTextForSequenceClassification,
    Data2VecTextModel,
)
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)

# IMPORTANT: In order for this script to run, please make sure to download the dictionary: `dict.txt` from wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
# File copied from https://github.com/pytorch/fairseq/blob/main/examples/data2vec/models/data2vec_text.py
from transformers.utils import logging


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_data2vec_checkpoint_to_pytorch(
    data2vec_checkpoint_path: str, pytorch_dump_folder_path: str, classification_head: bool
):
    """
    Copy/paste/tweak data2vec's weights to our BERT structure.
    """
    data2vec_checkpoint_dir, data2vec_checkpoint_file_name = os.path.split(data2vec_checkpoint_path)
    data2vec = Data2VecTextModel.from_pretrained(
        data2vec_checkpoint_dir, checkpoint_file=data2vec_checkpoint_file_name
    )
    data2vec.eval()  # disable dropout
    data2vec_model = data2vec.models[0]
    data2vec_sent_encoder = data2vec_model.encoder.sentence_encoder
    config = Data2VecTextConfig(
        vocab_size=data2vec_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=data2vec_model.args.encoder_embed_dim,
        num_hidden_layers=data2vec_model.args.encoder_layers,
        num_attention_heads=data2vec_model.args.encoder_attention_heads,
        intermediate_size=data2vec_model.args.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    if classification_head:
        config.num_labels = data2vec.model.classification_heads["mnli"].out_proj.weight.shape[0]
    print("Our BERT config:", config)

    model = Data2VecTextForSequenceClassification(config) if classification_head else Data2VecTextForMaskedLM(config)
    model.eval()

    # Now let's copy all the weights.
    # Embeddings
    model.data2vec_text.embeddings.word_embeddings.weight = data2vec_sent_encoder.embed_tokens.weight
    model.data2vec_text.embeddings.position_embeddings.weight = data2vec_sent_encoder.embed_positions.weight
    model.data2vec_text.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        model.data2vec_text.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c data2vec doesn't use them.
    model.data2vec_text.embeddings.LayerNorm.weight = data2vec_sent_encoder.layernorm_embedding.weight
    model.data2vec_text.embeddings.LayerNorm.bias = data2vec_sent_encoder.layernorm_embedding.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = model.data2vec_text.encoder.layer[i]
        data2vec_layer: TransformerSentenceEncoderLayer = data2vec_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert data2vec_layer.self_attn.k_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.k_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.q_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.q_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )
        assert data2vec_layer.self_attn.v_proj.weight.data.shape == torch.Size(
            (config.hidden_size, config.hidden_size)
        ), (
            "Shape for data2vec_layer.self_attn.v_proj.weight.data should be"
            f" {torch.Size((config.hidden_size, config.hidden_size))}"
        )

        self_attn.query.weight.data = data2vec_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = data2vec_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = data2vec_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = data2vec_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = data2vec_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = data2vec_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == data2vec_layer.self_attn.out_proj.weight.shape, (
            f"Shape for self_output.dense.weight should be {data2vec_layer.self_attn.out_proj.weight.shape}"
        )
        self_output.dense.weight = data2vec_layer.self_attn.out_proj.weight
        self_output.dense.bias = data2vec_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = data2vec_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = data2vec_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == data2vec_layer.fc1.weight.shape, (
            f"Shape for intermediate.dense.weight should be {data2vec_layer.fc1.weight.shape}"
        )
        intermediate.dense.weight = data2vec_layer.fc1.weight
        intermediate.dense.bias = data2vec_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == data2vec_layer.fc2.weight.shape, (
            f"Shape for bert_output.dense.weight should be {data2vec_layer.fc2.weight.shape}"
        )
        bert_output.dense.weight = data2vec_layer.fc2.weight
        bert_output.dense.bias = data2vec_layer.fc2.bias
        bert_output.LayerNorm.weight = data2vec_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = data2vec_layer.final_layer_norm.bias
        # end of layer

    if classification_head:
        model.classifier.dense.weight = data2vec.model.classification_heads["mnli"].dense.weight
        model.classifier.dense.bias = data2vec.model.classification_heads["mnli"].dense.bias
        model.classifier.out_proj.weight = data2vec.model.classification_heads["mnli"].out_proj.weight
        model.classifier.out_proj.bias = data2vec.model.classification_heads["mnli"].out_proj.bias
    else:
        # LM Head
        model.lm_head.dense.weight = data2vec_model.encoder.lm_head.dense.weight
        model.lm_head.dense.bias = data2vec_model.encoder.lm_head.dense.bias
        model.lm_head.layer_norm.weight = data2vec_model.encoder.lm_head.layer_norm.weight
        model.lm_head.layer_norm.bias = data2vec_model.encoder.lm_head.layer_norm.bias
        model.lm_head.decoder.weight = data2vec_model.encoder.lm_head.weight
        model.lm_head.decoder.bias = data2vec_model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = data2vec.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

    our_output = model(input_ids)[0]
    if classification_head:
        their_output = data2vec.model.classification_heads["mnli"](data2vec.extract_features(input_ids))
    else:
        their_output = data2vec_model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    args = parser.parse_args()
    convert_data2vec_checkpoint_to_pytorch(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.classification_head
    )
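

# Added usage sketch (not part of the original script): querying the converted masked-LM checkpoint.
# The output folder is a placeholder, and the tokenizer choice is an assumption: data2vec-text reuses the
# RoBERTa vocabulary, so a RoBERTa tokenizer is loaded separately here.
def _example_data2vec_text_inference():
    import torch
    from transformers import Data2VecTextForMaskedLM, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")  # assumption: RoBERTa vocabulary
    model = Data2VecTextForMaskedLM.from_pretrained("path/to/converted_model").eval()  # placeholder path

    inputs = tokenizer("Hello <mask>!", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
    print(tokenizer.decode(logits[0, mask_index].argmax(-1)))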
@ -1,374 +0,0 @@
#!/usr/bin/env python3
import argparse
import json

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.models import create_model

from transformers import (
    BeitImageProcessor,
    Data2VecVisionConfig,
    Data2VecVisionForImageClassification,
    Data2VecVisionModel,
)


def create_rename_keys(config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec."):
    prefix = "backbone." if is_semantic else ""

    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append(
            (f"{prefix}blocks.{i}.norm1.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_before.weight")
        )
        rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.weight", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.weight")
        )
        rename_keys.append(
            (f"{prefix}blocks.{i}.attn.proj.bias", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.bias")
        )
        rename_keys.append(
            (f"{prefix}blocks.{i}.norm2.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_after.weight")
        )
        rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append(
            (f"{prefix}blocks.{i}.mlp.fc1.weight", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.weight")
        )
        rename_keys.append(
            (f"{prefix}blocks.{i}.mlp.fc1.bias", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.bias")
        )
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"{hf_prefix}encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"{hf_prefix}encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            (f"{prefix}cls_token", f"{hf_prefix}embeddings.cls_token"),
            (f"{prefix}patch_embed.proj.weight", f"{hf_prefix}embeddings.patch_embeddings.projection.weight"),
            (f"{prefix}patch_embed.proj.bias", f"{hf_prefix}embeddings.patch_embeddings.projection.bias"),
        ]
    )

    if has_lm_head:
        # mask token + shared relative position bias + layernorm
        rename_keys.extend(
            [
                ("mask_token", f"{hf_prefix}embeddings.mask_token"),
                (
                    "rel_pos_bias.relative_position_bias_table",
                    f"{hf_prefix}encoder.relative_position_bias.relative_position_bias_table",
                ),
                (
                    "rel_pos_bias.relative_position_index",
                    f"{hf_prefix}encoder.relative_position_bias.relative_position_index",
                ),
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )
    elif is_semantic:
        # semantic segmentation classification heads
        rename_keys.extend(
            [
                ("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
                ("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
                ("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
                ("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
            ]
        )
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("fc_norm.weight", f"{hf_prefix}pooler.layernorm.weight"),
                ("fc_norm.bias", f"{hf_prefix}pooler.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys


def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec_vision."):
    for i in range(config.num_hidden_layers):
        prefix = "backbone." if is_semantic else ""
        # queries, keys and values
        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")

        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias

        # gamma_1 and gamma_2
        # we call them lambda because otherwise they are renamed when using .from_pretrained
        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")

        state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_1"] = gamma_1
        state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_2"] = gamma_2

        # relative_position bias table + index
        if not has_lm_head:
            # each layer has its own relative position bias
            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")

            state_dict[
                f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
            ] = table
            state_dict[
                f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
            ] = index
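

# Added illustration (not from the original file): how the fused timm `qkv` projection above is split into
# separate query / key / value weights. Shapes only, with an arbitrary toy hidden size:
def _example_qkv_split_sketch():
    import torch

    hidden_size = 8
    in_proj_weight = torch.randn(3 * hidden_size, hidden_size)  # fused qkv weight, rows ordered q, k, v

    query_weight = in_proj_weight[:hidden_size, :]
    key_weight = in_proj_weight[hidden_size : hidden_size * 2, :]
    value_weight = in_proj_weight[-hidden_size:, :]
    assert query_weight.shape == key_weight.shape == value_weight.shape == (hidden_size, hidden_size)

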
def get_args():
    parser = argparse.ArgumentParser(
        "Convert Data2VecVision to HF for image classification and pretraining", add_help=False
    )
    parser.add_argument("--hf_checkpoint_name", type=str)
    parser.add_argument("--input_size", default=224, type=int, help="images input size")
    parser.add_argument("--beit_checkpoint", default="", help="beit checkpoint")

    return parser.parse_args()


def load_beit_model(args, is_finetuned, is_large):
    def load_state_dict(model, state_dict, prefix="", ignore_missing="relative_position_index"):
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix=prefix)

        warn_missing_keys = []
        ignore_missing_keys = []
        for key in missing_keys:
            keep_flag = True
            for ignore_key in ignore_missing.split("|"):
                if ignore_key in key:
                    keep_flag = False
                    break
            if keep_flag:
                warn_missing_keys.append(key)
            else:
                ignore_missing_keys.append(key)

        missing_keys = warn_missing_keys

        if len(missing_keys) > 0:
            print(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys))
        if len(ignore_missing_keys) > 0:
            print(
                "Ignored weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, ignore_missing_keys
                )
            )
        if len(error_msgs) > 0:
            print("\n".join(error_msgs))

    model_kwargs = {
        "pretrained": False,
        "use_shared_rel_pos_bias": True,
        "use_abs_pos_emb": False,
        "init_values": 0.1,
    }

    if is_finetuned:
        model_kwargs.update(
            {
                "num_classes": 1000,
                "use_mean_pooling": True,
                "init_scale": 0.001,
                "use_rel_pos_bias": True,
            }
        )

    model = create_model(
        "beit_large_patch16_224" if is_large else "beit_base_patch16_224",
        **model_kwargs,
    )
    patch_size = model.patch_embed.patch_size
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    checkpoint = torch.load(args.beit_checkpoint, map_location="cpu")

    print(f"Load ckpt from {args.beit_checkpoint}")
    checkpoint_model = None
    for model_key in ("model", "module"):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print(f"Load state_dict by model_key = {model_key}")
            break

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            checkpoint_model.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()

    load_state_dict(model, checkpoint_model, prefix="")

    return model


def main():
    args = get_args()

    is_finetuned = "ft1k" in args.hf_checkpoint_name
    is_large = "large" in args.hf_checkpoint_name

    if is_finetuned:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py
        # into this folder.
        import modeling_finetune  # noqa: F401
    else:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py
        # into this folder
        # IMPORTANT: Note that for now we've only converted the down-stream
        # model and not the full pretrained model. This means for the integration
        # test you need to add a `return x` after the following line:
        # https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197
        # to make the integration test pass.
        import modeling_cyclical  # noqa: F401

    # 1. Create model config
    config = Data2VecVisionConfig()
    if is_finetuned:
        config.use_relative_position_bias = True
        config.use_shared_relative_position_bias = False
        config.use_mean_pooling = True
        config.num_labels = 1000

        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    else:
        config.use_relative_position_bias = False
        config.use_shared_relative_position_bias = True
        config.use_mean_pooling = False

    if is_large:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # 2. Load Beit model
    orig_model = load_beit_model(args, is_finetuned, is_large)
    orig_model.eval()

    # 3. Forward Beit model
    image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
    image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png")
    encoding = image_processor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    orig_args = (pixel_values,) if is_finetuned else (pixel_values, None)
    with torch.no_grad():
        orig_model_output = orig_model(*orig_args)

    # 4. Load HF Data2VecVision model
    if is_finetuned:
        hf_model = Data2VecVisionForImageClassification(config)
        hf_model.eval()
        has_lm_head = False
        hf_prefix = "data2vec_vision."
    else:
        hf_model = Data2VecVisionModel(config)
        hf_model.eval()
        has_lm_head = True
        hf_prefix = ""

    rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    state_dict = orig_model.state_dict()
    for src, dest in rename_keys:
        val = state_dict.pop(src)
        state_dict[dest] = val

    read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
    print("HF missing", missing_keys)
    print("HF unexpected_keys", unexpected_keys)

    # 5. Forward HF Data2VecVision model
    with torch.no_grad():
        hf_model_output = hf_model(pixel_values)

    hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state

    # 6. Compare
    max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item()

    print(f"max_absolute_diff = {max_absolute_diff}")
    success = torch.allclose(hf_output, orig_model_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    # 7. Save
    print(f"Saving to {args.hf_checkpoint_name}")
    hf_model.save_pretrained(args.hf_checkpoint_name)
    image_processor.save_pretrained(args.hf_checkpoint_name)


if __name__ == "__main__":
    main()
    # Run the following to convert checkpoints
    #  python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #          --beit_checkpoint ./pretrained_base.pt \
    #          --hf_checkpoint_name "./data2vec-vision-base"
    #  python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #          --beit_checkpoint ./finetuned_base.pt \
    #          --hf_checkpoint_name "./data2vec-vision-base-ft1k"
    #  python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #          --beit_checkpoint ./pretrained_large.pt \
    #          --hf_checkpoint_name "./data2vec-vision-large"
    #  python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
    #          --beit_checkpoint ./finetuned_large.pt \
    #          --hf_checkpoint_name "./data2vec-vision-large-ft1k"
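

# Added usage sketch (not part of the original script): classifying an image with a converted fine-tuned
# checkpoint. The folder path and image file are placeholders; the calls are the standard
# BeitImageProcessor / Data2VecVisionForImageClassification interface.
def _example_data2vec_vision_inference():
    import torch
    from PIL import Image
    from transformers import BeitImageProcessor, Data2VecVisionForImageClassification

    image_processor = BeitImageProcessor.from_pretrained("./data2vec-vision-base-ft1k")  # placeholder path
    model = Data2VecVisionForImageClassification.from_pretrained("./data2vec-vision-base-ft1k").eval()

    image = Image.open("cats.png")  # placeholder image
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(model.config.id2label[logits.argmax(-1).item()])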
@ -1,236 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Deformable DETR checkpoints."""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection, DeformableDetrImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def rename_key(orig_key):
    if "backbone.0.body" in orig_key:
        orig_key = orig_key.replace("backbone.0.body", "backbone.conv_encoder.model")
    if "transformer" in orig_key:
        orig_key = orig_key.replace("transformer.", "")
    if "norm1" in orig_key:
        if "encoder" in orig_key:
            orig_key = orig_key.replace("norm1", "self_attn_layer_norm")
        else:
            orig_key = orig_key.replace("norm1", "encoder_attn_layer_norm")
    if "norm2" in orig_key:
        if "encoder" in orig_key:
            orig_key = orig_key.replace("norm2", "final_layer_norm")
        else:
            orig_key = orig_key.replace("norm2", "self_attn_layer_norm")
    if "norm3" in orig_key:
        orig_key = orig_key.replace("norm3", "final_layer_norm")
    if "linear1" in orig_key:
        orig_key = orig_key.replace("linear1", "fc1")
    if "linear2" in orig_key:
        orig_key = orig_key.replace("linear2", "fc2")
    if "query_embed" in orig_key:
        orig_key = orig_key.replace("query_embed", "query_position_embeddings")
    if "cross_attn" in orig_key:
        orig_key = orig_key.replace("cross_attn", "encoder_attn")

    return orig_key


def read_in_q_k_v(state_dict):
    # transformer decoder self-attention layers
    for i in range(6):
        # read in weights + bias of input projection layer of self-attention
        in_proj_weight = state_dict.pop(f"decoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"decoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
        state_dict[f"decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
        state_dict[f"decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
        state_dict[f"decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im
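

# Added usage sketch (not part of the original script): once the model is converted and verified below, raw
# outputs can be turned into boxes and labels with the image processor. `post_process_object_detection` is the
# standard transformers post-processing call; the 0.5 threshold is an arbitrary example value.
def _example_postprocess_detections(model, image_processor, image):
    import torch

    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    target_sizes = torch.tensor([image.size[::-1]])  # (height, width) expected by the post-processor
    results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")

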
@torch.no_grad()
def convert_deformable_detr_checkpoint(
    checkpoint_path,
    single_scale,
    dilation,
    with_box_refine,
    two_stage,
    pytorch_dump_folder_path,
    push_to_hub,
):
    """
    Copy/paste/tweak model's weights to our Deformable DETR structure.
    """

    # load default config
    config = DeformableDetrConfig()
    # set config attributes
    if single_scale:
        config.num_feature_levels = 1
    config.dilation = dilation
    config.with_box_refine = with_box_refine
    config.two_stage = two_stage
    # set labels
    config.num_labels = 91
    repo_id = "huggingface/label-files"
    filename = "coco-detection-id2label.json"
    id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    # load image processor
    image_processor = DeformableDetrImageProcessor(format="coco_detection")

    # prepare image
    img = prepare_img()
    encoding = image_processor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info("Converting model...")

    # load original state dict
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    # rename keys
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "model."
    for key in state_dict.copy().keys():
        if not key.startswith("class_embed") and not key.startswith("bbox_embed"):
            val = state_dict.pop(key)
            state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = DeformableDetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # verify our conversion
    outputs = model(pixel_values.to(device))

    expected_logits = torch.tensor(
        [[-9.6645, -4.3449, -5.8705], [-9.7035, -3.8504, -5.0724], [-10.5634, -5.3379, -7.5116]]
    )
    expected_boxes = torch.tensor([[0.8693, 0.2289, 0.2492], [0.3150, 0.5489, 0.5845], [0.5563, 0.7580, 0.8518]])

    if single_scale:
        expected_logits = torch.tensor(
            [[-9.9051, -4.2541, -6.4852], [-9.6947, -4.0854, -6.8033], [-10.0665, -5.8470, -7.7003]]
        )
        expected_boxes = torch.tensor([[0.7292, 0.4991, 0.5532], [0.7959, 0.2426, 0.4236], [0.7582, 0.3518, 0.4451]])

    if single_scale and dilation:
        expected_logits = torch.tensor(
 | 
			
		||||
            [[-8.9652, -4.1074, -5.6635], [-9.0596, -4.9447, -6.6075], [-10.1178, -4.5275, -6.2671]]
 | 
			
		||||
        )
 | 
			
		||||
        expected_boxes = torch.tensor([[0.7665, 0.4130, 0.4769], [0.8364, 0.1841, 0.3391], [0.6261, 0.3895, 0.7978]])
 | 
			
		||||
 | 
			
		||||
    if with_box_refine:
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [[-8.8895, -5.4187, -6.8153], [-8.4706, -6.1668, -7.6184], [-9.0042, -5.5359, -6.9141]]
 | 
			
		||||
        )
 | 
			
		||||
        expected_boxes = torch.tensor([[0.7828, 0.2208, 0.4323], [0.0892, 0.5996, 0.1319], [0.5524, 0.6389, 0.8914]])
 | 
			
		||||
 | 
			
		||||
    if with_box_refine and two_stage:
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [[-6.7108, -4.3213, -6.3777], [-8.9014, -6.1799, -6.7240], [-6.9315, -4.4735, -6.2298]]
 | 
			
		||||
        )
 | 
			
		||||
        expected_boxes = torch.tensor([[0.2583, 0.5499, 0.4683], [0.7652, 0.9068, 0.4882], [0.5490, 0.2763, 0.0564]])
 | 
			
		||||
 | 
			
		||||
    print("Logits:", outputs.logits[0, :3, :3])
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(outputs.logits[0, :3, :3], expected_logits.to(device), atol=1e-4)
 | 
			
		||||
    assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes.to(device), atol=1e-4)
 | 
			
		||||
 | 
			
		||||
    print("Everything ok!")
 | 
			
		||||
 | 
			
		||||
    # Save model and image processor
 | 
			
		||||
    logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
 | 
			
		||||
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
    model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
    image_processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    # Push to hub
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        model_name = "deformable-detr"
 | 
			
		||||
        model_name += "-single-scale" if single_scale else ""
 | 
			
		||||
        model_name += "-dc5" if dilation else ""
 | 
			
		||||
        model_name += "-with-box-refine" if with_box_refine else ""
 | 
			
		||||
        model_name += "-two-stage" if two_stage else ""
 | 
			
		||||
        print("Pushing model to hub...")
 | 
			
		||||
        model.push_to_hub(repo_path_or_name=model_name, organization="nielsr", commit_message="Add model")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--checkpoint_path",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="/home/niels/checkpoints/deformable_detr/r50_deformable_detr-checkpoint.pth",
 | 
			
		||||
        help="Path to Pytorch checkpoint (.pth file) you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--single_scale", action="store_true", help="Whether to set config.num_features_levels = 1.")
 | 
			
		||||
    parser.add_argument("--dilation", action="store_true", help="Whether to set config.dilation=True.")
 | 
			
		||||
    parser.add_argument("--with_box_refine", action="store_true", help="Whether to set config.with_box_refine=True.")
 | 
			
		||||
    parser.add_argument("--two_stage", action="store_true", help="Whether to set config.two_stage=True.")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to the folder to output PyTorch model.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_deformable_detr_checkpoint(
 | 
			
		||||
        args.checkpoint_path,
 | 
			
		||||
        args.single_scale,
 | 
			
		||||
        args.dilation,
 | 
			
		||||
        args.with_box_refine,
 | 
			
		||||
        args.two_stage,
 | 
			
		||||
        args.pytorch_dump_folder_path,
 | 
			
		||||
        args.push_to_hub,
 | 
			
		||||
    )
 | 
			
		||||
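
For quick verification after running the script above, a minimal usage sketch; the "./deformable-detr" dump folder is a hypothetical path (whatever was passed as --pytorch_dump_folder_path), and only the public from_pretrained API is assumed:

import requests
from PIL import Image
from transformers import DeformableDetrForObjectDetection, DeformableDetrImageProcessor

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = DeformableDetrImageProcessor.from_pretrained("./deformable-detr")  # hypothetical dump folder
model = DeformableDetrForObjectDetection.from_pretrained("./deformable-detr")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # one score vector per object query
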
@@ -1,218 +0,0 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DeiT distilled checkpoints from the timm library."""

import argparse
import json
from pathlib import Path

import requests
import timm
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import DeiTConfig, DeiTForImageClassificationWithTeacher, DeiTImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"deit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"deit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"deit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"deit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"deit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"deit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"deit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"deit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"deit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"deit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("cls_token", "deit.embeddings.cls_token"),
            ("dist_token", "deit.embeddings.distillation_token"),
            ("patch_embed.proj.weight", "deit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "deit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "deit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm + pooler
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
                ("pre_logits.fc.weight", "pooler.dense.weight"),
                ("pre_logits.fc.bias", "pooler.dense.bias"),
            ]
        )

        # if just the base model, we should remove "deit" from all keys that start with "deit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("deit") else pair for pair in rename_keys]
    else:
        # layernorm + classification heads
        rename_keys.extend(
            [
                ("norm.weight", "deit.layernorm.weight"),
                ("norm.bias", "deit.layernorm.bias"),
                ("head.weight", "cls_classifier.weight"),
                ("head.bias", "cls_classifier.bias"),
                ("head_dist.weight", "distillation_classifier.weight"),
                ("head_dist.bias", "distillation_classifier.bias"),
            ]
        )

    return rename_keys


# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "deit."
        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@torch.no_grad()
def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DeiT structure.
    """

    # define default DeiT configuration
    config = DeiTConfig()
    # all deit models have fine-tuned heads
    base_model = False
    # dataset (fine-tuned on ImageNet 2012), patch_size and image_size
    config.num_labels = 1000
    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    config.patch_size = int(deit_name[-6:-4])
    config.image_size = int(deit_name[-3:])
    # size of the architecture
    if deit_name[9:].startswith("tiny"):
        config.hidden_size = 192
        config.intermediate_size = 768
        config.num_hidden_layers = 12
        config.num_attention_heads = 3
    elif deit_name[9:].startswith("small"):
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6
    if deit_name[9:].startswith("base"):
        pass
    elif deit_name[4:].startswith("large"):
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(deit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    model = DeiTForImageClassificationWithTeacher(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by DeiTImageProcessor
    size = int(
        (256 / 224) * config.image_size
    )  # to maintain same ratio w.r.t. 224 images, see https://github.com/facebookresearch/deit/blob/ab5715372db8c6cad5740714b2216d55aeae052e/datasets.py#L103
    image_processor = DeiTImageProcessor(size=size, crop_size=config.image_size)
    encoding = image_processor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    timm_logits = timm_model(pixel_values)
    assert timm_logits.shape == outputs.logits.shape
    assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {deit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--deit_name",
        default="vit_deit_base_distilled_patch16_224",
        type=str,
        help="Name of the DeiT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_deit_checkpoint(args.deit_name, args.pytorch_dump_folder_path)
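
Again for reference, a minimal sketch of reloading the converted DeiT checkpoint; "./deit-dump" is a hypothetical dump folder, and only the public from_pretrained API is assumed:

import requests
from PIL import Image
from transformers import DeiTForImageClassificationWithTeacher, DeiTImageProcessor

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
processor = DeiTImageProcessor.from_pretrained("./deit-dump")  # hypothetical dump folder
model = DeiTForImageClassificationWithTeacher.from_pretrained("./deit-dump")
inputs = processor(images=image, return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
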
@@ -1,318 +0,0 @@
# coding=utf-8
# Copyright 2020, The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bort checkpoint."""

import argparse
import os

import gluonnlp as nlp
import mxnet as mx
import numpy as np
import torch
from gluonnlp.base import get_home_dir
from gluonnlp.model.bert import BERTEncoder
from gluonnlp.model.utils import _load_vocab
from gluonnlp.vocab import Vocab
from packaging import version
from torch import nn

from transformers import BertConfig, BertForMaskedLM, BertModel, RobertaTokenizer
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.utils import logging


if version.parse(nlp.__version__) != version.parse("0.8.3"):
    raise Exception("requires gluonnlp == 0.8.3")

if version.parse(mx.__version__) != version.parse("1.5.0"):
    raise Exception("requires mxnet == 1.5.0")

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "The Nymphenburg Palace is a beautiful palace in Munich!"


def convert_bort_checkpoint_to_pytorch(bort_checkpoint_path: str, pytorch_dump_folder_path: str):
    """
    Convert the original Bort checkpoint (based on MXNET and Gluonnlp) to our BERT structure.
    """

    # Original Bort configuration
    bort_4_8_768_1024_hparams = {
        "attention_cell": "multi_head",
        "num_layers": 4,
        "units": 1024,
        "hidden_size": 768,
        "max_length": 512,
        "num_heads": 8,
        "scaled": True,
        "dropout": 0.1,
        "use_residual": True,
        "embed_size": 1024,
        "embed_dropout": 0.1,
        "word_embed": None,
        "layer_norm_eps": 1e-5,
        "token_type_vocab_size": 2,
    }

    predefined_args = bort_4_8_768_1024_hparams

    # Let's construct the original Bort model here
    # Taken from official BERT implementation, see:
    # https://github.com/alexa/bort/blob/master/bort/bort.py
    encoder = BERTEncoder(
        attention_cell=predefined_args["attention_cell"],
        num_layers=predefined_args["num_layers"],
        units=predefined_args["units"],
        hidden_size=predefined_args["hidden_size"],
        max_length=predefined_args["max_length"],
        num_heads=predefined_args["num_heads"],
        scaled=predefined_args["scaled"],
        dropout=predefined_args["dropout"],
        output_attention=False,
        output_all_encodings=False,
        use_residual=predefined_args["use_residual"],
        activation=predefined_args.get("activation", "gelu"),
        layer_norm_eps=predefined_args.get("layer_norm_eps", None),
    )

    # Vocab information needs to be fetched first
    # It's the same as RoBERTa, so RobertaTokenizer can be used later
    vocab_name = "openwebtext_ccnews_stories_books_cased"

    # Specify download folder to Gluonnlp's vocab
    gluon_cache_dir = os.path.join(get_home_dir(), "models")
    bort_vocab = _load_vocab(vocab_name, None, gluon_cache_dir, cls=Vocab)

    original_bort = nlp.model.BERTModel(
        encoder,
        len(bort_vocab),
        units=predefined_args["units"],
        embed_size=predefined_args["embed_size"],
        embed_dropout=predefined_args["embed_dropout"],
        word_embed=predefined_args["word_embed"],
        use_pooler=False,
        use_token_type_embed=False,
        token_type_vocab_size=predefined_args["token_type_vocab_size"],
        use_classifier=False,
        use_decoder=False,
    )

    original_bort.load_parameters(bort_checkpoint_path, cast_dtype=True, ignore_extra=True)
    params = original_bort._collect_params_with_prefix()

    # Build our config 🤗
    hf_bort_config_json = {
        "architectures": ["BertForMaskedLM"],
        "attention_probs_dropout_prob": predefined_args["dropout"],
        "hidden_act": "gelu",
        "hidden_dropout_prob": predefined_args["dropout"],
        "hidden_size": predefined_args["embed_size"],
        "initializer_range": 0.02,
        "intermediate_size": predefined_args["hidden_size"],
        "layer_norm_eps": predefined_args["layer_norm_eps"],
        "max_position_embeddings": predefined_args["max_length"],
        "model_type": "bort",
        "num_attention_heads": predefined_args["num_heads"],
        "num_hidden_layers": predefined_args["num_layers"],
        "pad_token_id": 1,  # 2 = BERT, 1 = RoBERTa
        "type_vocab_size": 1,  # 2 = BERT, 1 = RoBERTa
        "vocab_size": len(bort_vocab),
    }

    hf_bort_config = BertConfig.from_dict(hf_bort_config_json)
    hf_bort_model = BertForMaskedLM(hf_bort_config)
    hf_bort_model.eval()

    # Parameter mapping table (Gluonnlp to Transformers)
    # * denotes layer index
    #
    # | Gluon Parameter                                                 | Transformers Parameter
    # | --------------------------------------------------------------- | ----------------------
    # | `encoder.layer_norm.beta`                                       | `bert.embeddings.LayerNorm.bias`
    # | `encoder.layer_norm.gamma`                                      | `bert.embeddings.LayerNorm.weight`
    # | `encoder.position_weight`                                       | `bert.embeddings.position_embeddings.weight`
    # | `word_embed.0.weight`                                           | `bert.embeddings.word_embeddings.weight`
    # | `encoder.transformer_cells.*.attention_cell.proj_key.bias`      | `bert.encoder.layer.*.attention.self.key.bias`
    # | `encoder.transformer_cells.*.attention_cell.proj_key.weight`    | `bert.encoder.layer.*.attention.self.key.weight`
    # | `encoder.transformer_cells.*.attention_cell.proj_query.bias`    | `bert.encoder.layer.*.attention.self.query.bias`
    # | `encoder.transformer_cells.*.attention_cell.proj_query.weight`  | `bert.encoder.layer.*.attention.self.query.weight`
    # | `encoder.transformer_cells.*.attention_cell.proj_value.bias`    | `bert.encoder.layer.*.attention.self.value.bias`
    # | `encoder.transformer_cells.*.attention_cell.proj_value.weight`  | `bert.encoder.layer.*.attention.self.value.weight`
    # | `encoder.transformer_cells.*.ffn.ffn_2.bias`                    | `bert.encoder.layer.*.attention.output.dense.bias`
    # | `encoder.transformer_cells.*.ffn.ffn_2.weight`                  | `bert.encoder.layer.*.attention.output.dense.weight`
    # | `encoder.transformer_cells.*.layer_norm.beta`                   | `bert.encoder.layer.*.attention.output.LayerNorm.bias`
    # | `encoder.transformer_cells.*.layer_norm.gamma`                  | `bert.encoder.layer.*.attention.output.LayerNorm.weight`
    # | `encoder.transformer_cells.*.ffn.ffn_1.bias`                    | `bert.encoder.layer.*.intermediate.dense.bias`
    # | `encoder.transformer_cells.*.ffn.ffn_1.weight`                  | `bert.encoder.layer.*.intermediate.dense.weight`
    # | `encoder.transformer_cells.*.ffn.layer_norm.beta`               | `bert.encoder.layer.*.output.LayerNorm.bias`
    # | `encoder.transformer_cells.*.ffn.layer_norm.gamma`              | `bert.encoder.layer.*.output.LayerNorm.weight`
    # | `encoder.transformer_cells.*.proj.bias`                         | `bert.encoder.layer.*.output.dense.bias`
    # | `encoder.transformer_cells.*.proj.weight`                       | `bert.encoder.layer.*.output.dense.weight`

    # Helper function to convert MXNET Arrays to PyTorch
    def to_torch(mx_array) -> nn.Parameter:
        return nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy()))

    # Check param shapes and map new HF param back
    def check_and_map_params(hf_param, gluon_param):
        shape_hf = hf_param.shape

        gluon_param = to_torch(params[gluon_param])
        shape_gluon = gluon_param.shape

        assert shape_hf == shape_gluon, (
            f"The gluon parameter {gluon_param} has shape {shape_gluon}, but expects shape {shape_hf} for Transformers"
        )

        return gluon_param

    hf_bort_model.bert.embeddings.word_embeddings.weight = check_and_map_params(
        hf_bort_model.bert.embeddings.word_embeddings.weight, "word_embed.0.weight"
    )
    hf_bort_model.bert.embeddings.position_embeddings.weight = check_and_map_params(
        hf_bort_model.bert.embeddings.position_embeddings.weight, "encoder.position_weight"
    )
    hf_bort_model.bert.embeddings.LayerNorm.bias = check_and_map_params(
        hf_bort_model.bert.embeddings.LayerNorm.bias, "encoder.layer_norm.beta"
    )
    hf_bort_model.bert.embeddings.LayerNorm.weight = check_and_map_params(
        hf_bort_model.bert.embeddings.LayerNorm.weight, "encoder.layer_norm.gamma"
    )

    # Inspired by RoBERTa conversion script, we just zero them out (Bort does not use them)
    hf_bort_model.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        hf_bort_model.bert.embeddings.token_type_embeddings.weight.data
    )

    for i in range(hf_bort_config.num_hidden_layers):
        layer: BertLayer = hf_bort_model.bert.encoder.layer[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self

        self_attn.key.bias.data = check_and_map_params(
            self_attn.key.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_key.bias"
        )

        self_attn.key.weight.data = check_and_map_params(
            self_attn.key.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_key.weight"
        )
        self_attn.query.bias.data = check_and_map_params(
            self_attn.query.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_query.bias"
        )
        self_attn.query.weight.data = check_and_map_params(
            self_attn.query.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_query.weight"
        )
        self_attn.value.bias.data = check_and_map_params(
            self_attn.value.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_value.bias"
        )
        self_attn.value.weight.data = check_and_map_params(
            self_attn.value.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_value.weight"
        )

        # self attention output
        self_output: BertSelfOutput = layer.attention.output

        self_output.dense.bias = check_and_map_params(
            self_output.dense.bias, f"encoder.transformer_cells.{i}.proj.bias"
        )
        self_output.dense.weight = check_and_map_params(
            self_output.dense.weight, f"encoder.transformer_cells.{i}.proj.weight"
        )
        self_output.LayerNorm.bias = check_and_map_params(
            self_output.LayerNorm.bias, f"encoder.transformer_cells.{i}.layer_norm.beta"
        )
        self_output.LayerNorm.weight = check_and_map_params(
            self_output.LayerNorm.weight, f"encoder.transformer_cells.{i}.layer_norm.gamma"
        )

        # intermediate
        intermediate: BertIntermediate = layer.intermediate

        intermediate.dense.bias = check_and_map_params(
            intermediate.dense.bias, f"encoder.transformer_cells.{i}.ffn.ffn_1.bias"
        )
        intermediate.dense.weight = check_and_map_params(
            intermediate.dense.weight, f"encoder.transformer_cells.{i}.ffn.ffn_1.weight"
        )

        # output
        bert_output: BertOutput = layer.output

        bert_output.dense.bias = check_and_map_params(
            bert_output.dense.bias, f"encoder.transformer_cells.{i}.ffn.ffn_2.bias"
        )
        bert_output.dense.weight = check_and_map_params(
            bert_output.dense.weight, f"encoder.transformer_cells.{i}.ffn.ffn_2.weight"
        )
        bert_output.LayerNorm.bias = check_and_map_params(
            bert_output.LayerNorm.bias, f"encoder.transformer_cells.{i}.ffn.layer_norm.beta"
        )
        bert_output.LayerNorm.weight = check_and_map_params(
            bert_output.LayerNorm.weight, f"encoder.transformer_cells.{i}.ffn.layer_norm.gamma"
        )

    # Save space and energy 🎄
    hf_bort_model.half()

    # Compare output of both models
    tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")

    input_ids = tokenizer.encode_plus(SAMPLE_TEXT)["input_ids"]

    # Get gluon output
    gluon_input_ids = mx.nd.array([input_ids])
    output_gluon = original_bort(inputs=gluon_input_ids, token_types=[])

    # Get Transformer output (save and reload model again)
    hf_bort_model.save_pretrained(pytorch_dump_folder_path)
    hf_bort_model = BertModel.from_pretrained(pytorch_dump_folder_path)
    hf_bort_model.eval()

    input_ids = tokenizer.encode_plus(SAMPLE_TEXT, return_tensors="pt")
    output_hf = hf_bort_model(**input_ids)[0]

    gluon_layer = output_gluon[0].asnumpy()
    hf_layer = output_hf[0].detach().numpy()

    max_absolute_diff = np.max(np.abs(hf_layer - gluon_layer)).item()
    success = np.allclose(gluon_layer, hf_layer, atol=1e-3)

    if success:
        print("✔️ Both models output the same tensors")
    else:
        print("❌ Both models do **NOT** output the same tensors")
        print("Absolute difference is:", max_absolute_diff)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--bort_checkpoint_path", default=None, type=str, required=True, help="Path to the official Bort params file."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_bort_checkpoint_to_pytorch(args.bort_checkpoint_path, args.pytorch_dump_folder_path)
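
As a reference point, a minimal sketch of reloading the converted Bort weights; "./bort-dump" is a hypothetical dump folder, and, as in the script above, the RoBERTa tokenizer is reused because Bort shares its vocabulary:

import torch
from transformers import BertModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")
model = BertModel.from_pretrained("./bort-dump")  # hypothetical dump folder
model.eval()
inputs = tokenizer("The Nymphenburg Palace is a beautiful palace in Munich!", return_tensors="pt")
with torch.no_grad():
    hidden_states = model(**inputs).last_hidden_state
print(hidden_states.shape)  # (1, sequence_length, hidden_size)
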
@@ -1,319 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETA checkpoints from the original repository.

URL: https://github.com/jozhang97/DETA/tree/master"""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import DetaConfig, DetaForObjectDetection, DetaImageProcessor
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


def get_deta_config():
    config = DetaConfig(
        num_queries=900,
        encoder_ffn_dim=2048,
        decoder_ffn_dim=2048,
        num_feature_levels=5,
        assign_first_stage=True,
        with_box_refine=True,
        two_stage=True,
    )

    # set labels
    config.num_labels = 91
    repo_id = "huggingface/label-files"
    filename = "coco-detection-id2label.json"
    id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    return config


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
    rename_keys = []

    # stem
    # fmt: off
    rename_keys.append(("backbone.0.body.conv1.weight", "model.backbone.model.embedder.embedder.convolution.weight"))
    rename_keys.append(("backbone.0.body.bn1.weight", "model.backbone.model.embedder.embedder.normalization.weight"))
    rename_keys.append(("backbone.0.body.bn1.bias", "model.backbone.model.embedder.embedder.normalization.bias"))
    rename_keys.append(("backbone.0.body.bn1.running_mean", "model.backbone.model.embedder.embedder.normalization.running_mean"))
    rename_keys.append(("backbone.0.body.bn1.running_var", "model.backbone.model.embedder.embedder.normalization.running_var"))
    # stages
    for stage_idx in range(len(config.backbone_config.depths)):
        for layer_idx in range(config.backbone_config.depths[stage_idx]):
            # shortcut
            if layer_idx == 0:
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.0.weight",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.weight",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.bias",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_mean",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.downsample.1.running_var",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
                    )
                )
            # 3 convs
            for i in range(3):
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.conv{i+1}.weight",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.weight",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.bias",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_mean",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
                    )
                )
                rename_keys.append(
                    (
                        f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}.bn{i+1}.running_var",
                        f"model.backbone.model.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
                    )
                )
    # transformer encoder
    for i in range(config.encoder_layers):
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.sampling_offsets.weight", f"model.encoder.layers.{i}.self_attn.sampling_offsets.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.sampling_offsets.bias", f"model.encoder.layers.{i}.self_attn.sampling_offsets.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.attention_weights.weight", f"model.encoder.layers.{i}.self_attn.attention_weights.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.attention_weights.bias", f"model.encoder.layers.{i}.self_attn.attention_weights.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.value_proj.weight", f"model.encoder.layers.{i}.self_attn.value_proj.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.value_proj.bias", f"model.encoder.layers.{i}.self_attn.value_proj.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.output_proj.weight", f"model.encoder.layers.{i}.self_attn.output_proj.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.output_proj.bias", f"model.encoder.layers.{i}.self_attn.output_proj.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.norm1.weight", f"model.encoder.layers.{i}.self_attn_layer_norm.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"model.encoder.layers.{i}.self_attn_layer_norm.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"model.encoder.layers.{i}.fc1.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"model.encoder.layers.{i}.fc1.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"model.encoder.layers.{i}.fc2.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"model.encoder.layers.{i}.fc2.bias"))
        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"model.encoder.layers.{i}.final_layer_norm.weight"))
        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"model.encoder.layers.{i}.final_layer_norm.bias"))

    # transformer decoder
    for i in range(config.decoder_layers):
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.sampling_offsets.weight", f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.sampling_offsets.bias", f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.attention_weights.weight", f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.attention_weights.bias", f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.value_proj.weight", f"model.decoder.layers.{i}.encoder_attn.value_proj.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.value_proj.bias", f"model.decoder.layers.{i}.encoder_attn.value_proj.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.output_proj.weight", f"model.decoder.layers.{i}.encoder_attn.output_proj.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.output_proj.bias", f"model.decoder.layers.{i}.encoder_attn.output_proj.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"model.decoder.layers.{i}.self_attn.out_proj.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"model.decoder.layers.{i}.self_attn.out_proj.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight"))
        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias"))

    # fmt: on

    return rename_keys


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def read_in_decoder_q_k_v(state_dict, config):
    # transformer decoder self-attention layers
    hidden_size = config.d_model
    for i in range(config.decoder_layers):
        # read in weights + bias of input projection layer of self-attention
        in_proj_weight = state_dict.pop(f"transformer.decoder.layers.{i}.self_attn.in_proj_weight")
        in_proj_bias = state_dict.pop(f"transformer.decoder.layers.{i}.self_attn.in_proj_bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:hidden_size, :]
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:hidden_size]
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[
            hidden_size : hidden_size * 2, :
        ]
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size:, :]
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size:]


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)

    return im


@torch.no_grad()
def convert_deta_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
    """
    Copy/paste/tweak model's weights to our DETA structure.
    """

    # load config
    config = get_deta_config()

    # load original state dict
    if model_name == "deta-resnet-50":
        filename = "adet_checkpoint0011.pth"
    elif model_name == "deta-resnet-50-24-epochs":
        filename = "adet_2x_checkpoint0023.pth"
    else:
        raise ValueError(f"Model name {model_name} not supported")
    checkpoint_path = hf_hub_download(repo_id="nielsr/deta-checkpoints", filename=filename)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    # rename keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_decoder_q_k_v(state_dict, config)

    # fix some prefixes
    for key in state_dict.copy().keys():
        if "transformer.decoder.class_embed" in key or "transformer.decoder.bbox_embed" in key:
            val = state_dict.pop(key)
            state_dict[key.replace("transformer.decoder", "model.decoder")] = val
        if "input_proj" in key:
            val = state_dict.pop(key)
            state_dict["model." + key] = val
        if "level_embed" in key or "pos_trans" in key or "pix_trans" in key or "enc_output" in key:
            val = state_dict.pop(key)
            state_dict[key.replace("transformer", "model")] = val

    # finally, create HuggingFace model and load state dict
    model = DetaForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # load image processor
    processor = DetaImageProcessor(format="coco_detection")

    # verify our conversion on image
    img = prepare_img()
    encoding = processor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values.to(device))

    # verify logits
    if model_name == "deta-resnet-50":
        expected_logits = torch.tensor(
            [[-7.3978, -2.5406, -4.1668], [-8.2684, -3.9933, -3.8096], [-7.0515, -3.7973, -5.8516]]
        )
        expected_boxes = torch.tensor([[0.5043, 0.4973, 0.9998], [0.2542, 0.5489, 0.4748], [0.5490, 0.2765, 0.0570]])
    elif model_name == "deta-resnet-50-24-epochs":
        expected_logits = torch.tensor(
            [[-7.1688, -2.4857, -4.8669], [-7.8630, -3.8154, -4.2674], [-7.2730, -4.1865, -5.5323]]
        )
        expected_boxes = torch.tensor([[0.5021, 0.4971, 0.9994], [0.2546, 0.5486, 0.4731], [0.1686, 0.1986, 0.2142]])

    assert torch.allclose(outputs.logits[0, :3, :3], expected_logits.to(device), atol=1e-4)
    assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes.to(device), atol=1e-4)
    print("Everything ok!")

    if pytorch_dump_folder_path:
        # Save model and processor
        logger.info(f"Saving PyTorch model and processor to {pytorch_dump_folder_path}...")
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    # Push to hub
    if push_to_hub:
        print("Pushing model and processor to hub...")
        model.push_to_hub(f"jozhang97/{model_name}")
        processor.push_to_hub(f"jozhang97/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name",
        type=str,
        default="deta-resnet-50",
        choices=["deta-resnet-50", "deta-resnet-50-24-epochs"],
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the folder to output PyTorch model.",
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_deta_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
 | 
			
		||||
@ -1,326 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
"""Convert DETA checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/jozhang97/DETA/tree/master"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from huggingface_hub import hf_hub_download
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
from transformers import DetaConfig, DetaForObjectDetection, DetaImageProcessor, SwinConfig
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.set_verbosity_info()
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_deta_config(model_name):
 | 
			
		||||
    backbone_config = SwinConfig(
 | 
			
		||||
        embed_dim=192,
 | 
			
		||||
        depths=(2, 2, 18, 2),
 | 
			
		||||
        num_heads=(6, 12, 24, 48),
 | 
			
		||||
        window_size=12,
 | 
			
		||||
        out_features=["stage2", "stage3", "stage4"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    config = DetaConfig(
 | 
			
		||||
        backbone_config=backbone_config,
 | 
			
		||||
        num_queries=900,
 | 
			
		||||
        encoder_ffn_dim=2048,
 | 
			
		||||
        decoder_ffn_dim=2048,
 | 
			
		||||
        num_feature_levels=5,
 | 
			
		||||
        assign_first_stage=True,
 | 
			
		||||
        with_box_refine=True,
 | 
			
		||||
        two_stage=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # set labels
 | 
			
		||||
    repo_id = "huggingface/label-files"
 | 
			
		||||
    if "o365" in model_name:
 | 
			
		||||
        num_labels = 366
 | 
			
		||||
        filename = "object365-id2label.json"
 | 
			
		||||
    else:
 | 
			
		||||
        num_labels = 91
 | 
			
		||||
        filename = "coco-detection-id2label.json"
 | 
			
		||||
 | 
			
		||||
    config.num_labels = num_labels
 | 
			
		||||
    id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
 | 
			
		||||
    id2label = {int(k): v for k, v in id2label.items()}
 | 
			
		||||
    config.id2label = id2label
 | 
			
		||||
    config.label2id = {v: k for k, v in id2label.items()}
 | 
			
		||||
 | 
			
		||||
    return config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# here we list all keys to be renamed (original name on the left, our name on the right)
 | 
			
		||||
def create_rename_keys(config):
 | 
			
		||||
    rename_keys = []
 | 
			
		||||
 | 
			
		||||
    # stem
 | 
			
		||||
    # fmt: off
 | 
			
		||||
    rename_keys.append(("backbone.0.body.patch_embed.proj.weight", "model.backbone.model.embeddings.patch_embeddings.projection.weight"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.patch_embed.proj.bias", "model.backbone.model.embeddings.patch_embeddings.projection.bias"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.patch_embed.norm.weight", "model.backbone.model.embeddings.norm.weight"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.patch_embed.norm.bias", "model.backbone.model.embeddings.norm.bias"))
 | 
			
		||||
    # stages
 | 
			
		||||
    for i in range(len(config.backbone_config.depths)):
 | 
			
		||||
        for j in range(config.backbone_config.depths[i]):
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.norm1.weight", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.norm1.bias", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.attn.relative_position_bias_table", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.attn.relative_position_index", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.attn.proj.weight", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.attn.proj.bias", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.norm2.weight", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.norm2.bias", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.mlp.fc1.weight", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.mlp.fc1.bias", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.mlp.fc2.weight", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.blocks.{j}.mlp.fc2.bias", f"model.backbone.model.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
 | 
			
		||||
 | 
			
		||||
        if i < 3:
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.downsample.reduction.weight", f"model.backbone.model.encoder.layers.{i}.downsample.reduction.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.downsample.norm.weight", f"model.backbone.model.encoder.layers.{i}.downsample.norm.weight"))
 | 
			
		||||
            rename_keys.append((f"backbone.0.body.layers.{i}.downsample.norm.bias", f"model.backbone.model.encoder.layers.{i}.downsample.norm.bias"))
 | 
			
		||||
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm1.weight", "model.backbone.model.hidden_states_norms.stage2.weight"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm1.bias", "model.backbone.model.hidden_states_norms.stage2.bias"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm2.weight", "model.backbone.model.hidden_states_norms.stage3.weight"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm2.bias", "model.backbone.model.hidden_states_norms.stage3.bias"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm3.weight", "model.backbone.model.hidden_states_norms.stage4.weight"))
 | 
			
		||||
    rename_keys.append(("backbone.0.body.norm3.bias", "model.backbone.model.hidden_states_norms.stage4.bias"))
 | 
			
		||||
 | 
			
		||||
    # transformer encoder
 | 
			
		||||
    for i in range(config.encoder_layers):
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.sampling_offsets.weight", f"model.encoder.layers.{i}.self_attn.sampling_offsets.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.sampling_offsets.bias", f"model.encoder.layers.{i}.self_attn.sampling_offsets.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.attention_weights.weight", f"model.encoder.layers.{i}.self_attn.attention_weights.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.attention_weights.bias", f"model.encoder.layers.{i}.self_attn.attention_weights.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.value_proj.weight", f"model.encoder.layers.{i}.self_attn.value_proj.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.value_proj.bias", f"model.encoder.layers.{i}.self_attn.value_proj.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.output_proj.weight", f"model.encoder.layers.{i}.self_attn.output_proj.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.self_attn.output_proj.bias", f"model.encoder.layers.{i}.self_attn.output_proj.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.norm1.weight", f"model.encoder.layers.{i}.self_attn_layer_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.norm1.bias", f"model.encoder.layers.{i}.self_attn_layer_norm.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.weight", f"model.encoder.layers.{i}.fc1.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.linear1.bias", f"model.encoder.layers.{i}.fc1.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.weight", f"model.encoder.layers.{i}.fc2.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.linear2.bias", f"model.encoder.layers.{i}.fc2.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.weight", f"model.encoder.layers.{i}.final_layer_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.encoder.layers.{i}.norm2.bias", f"model.encoder.layers.{i}.final_layer_norm.bias"))
 | 
			
		||||
 | 
			
		||||
    # transformer decoder
 | 
			
		||||
    for i in range(config.decoder_layers):
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.sampling_offsets.weight", f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.sampling_offsets.bias", f"model.decoder.layers.{i}.encoder_attn.sampling_offsets.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.attention_weights.weight", f"model.decoder.layers.{i}.encoder_attn.attention_weights.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.attention_weights.bias", f"model.decoder.layers.{i}.encoder_attn.attention_weights.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.value_proj.weight", f"model.decoder.layers.{i}.encoder_attn.value_proj.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.value_proj.bias", f"model.decoder.layers.{i}.encoder_attn.value_proj.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.output_proj.weight", f"model.decoder.layers.{i}.encoder_attn.output_proj.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.cross_attn.output_proj.bias", f"model.decoder.layers.{i}.encoder_attn.output_proj.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm1.weight", f"model.decoder.layers.{i}.encoder_attn_layer_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm1.bias", f"model.decoder.layers.{i}.encoder_attn_layer_norm.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.self_attn.out_proj.weight", f"model.decoder.layers.{i}.self_attn.out_proj.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.self_attn.out_proj.bias", f"model.decoder.layers.{i}.self_attn.out_proj.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm2.weight", f"model.decoder.layers.{i}.self_attn_layer_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm2.bias", f"model.decoder.layers.{i}.self_attn_layer_norm.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.weight", f"model.decoder.layers.{i}.fc1.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.linear1.bias", f"model.decoder.layers.{i}.fc1.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.weight", f"model.decoder.layers.{i}.fc2.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.linear2.bias", f"model.decoder.layers.{i}.fc2.bias"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.weight", f"model.decoder.layers.{i}.final_layer_norm.weight"))
 | 
			
		||||
        rename_keys.append((f"transformer.decoder.layers.{i}.norm3.bias", f"model.decoder.layers.{i}.final_layer_norm.bias"))
 | 
			
		||||
 | 
			
		||||
    # fmt: on
 | 
			
		||||
 | 
			
		||||
    return rename_keys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(dct, old, new):
 | 
			
		||||
    val = dct.pop(old)
 | 
			
		||||
    dct[new] = val
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# we split up the matrix of each encoder layer into queries, keys and values
 | 
			
		||||
def read_in_swin_q_k_v(state_dict, backbone_config):
 | 
			
		||||
    num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
 | 
			
		||||
    for i in range(len(backbone_config.depths)):
 | 
			
		||||
        dim = num_features[i]
 | 
			
		||||
        for j in range(backbone_config.depths[i]):
 | 
			
		||||
            # fmt: off
 | 
			
		||||
            # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
 | 
			
		||||
            in_proj_weight = state_dict.pop(f"backbone.0.body.layers.{i}.blocks.{j}.attn.qkv.weight")
 | 
			
		||||
            in_proj_bias = state_dict.pop(f"backbone.0.body.layers.{i}.blocks.{j}.attn.qkv.bias")
 | 
			
		||||
            # next, add query, keys and values (in that order) to the state dict
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
 | 
			
		||||
                dim : dim * 2, :
 | 
			
		||||
            ]
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
 | 
			
		||||
                dim : dim * 2
 | 
			
		||||
            ]
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
 | 
			
		||||
                -dim :, :
 | 
			
		||||
            ]
 | 
			
		||||
            state_dict[f"model.backbone.model.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
 | 
			
		||||
            # fmt: on
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_in_decoder_q_k_v(state_dict, config):
 | 
			
		||||
    # transformer decoder self-attention layers
 | 
			
		||||
    hidden_size = config.d_model
 | 
			
		||||
    for i in range(config.decoder_layers):
 | 
			
		||||
        # read in weights + bias of input projection layer of self-attention
 | 
			
		||||
        in_proj_weight = state_dict.pop(f"transformer.decoder.layers.{i}.self_attn.in_proj_weight")
 | 
			
		||||
        in_proj_bias = state_dict.pop(f"transformer.decoder.layers.{i}.self_attn.in_proj_bias")
 | 
			
		||||
        # next, add query, keys and values (in that order) to the state dict
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:hidden_size, :]
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:hidden_size]
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[
 | 
			
		||||
            hidden_size : hidden_size * 2, :
 | 
			
		||||
        ]
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size:, :]
 | 
			
		||||
        state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on an image of cute cats
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    im = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def convert_deta_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
 | 
			
		||||
    """
 | 
			
		||||
    Copy/paste/tweak model's weights to our DETA structure.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # load config
 | 
			
		||||
    config = get_deta_config(model_name)
 | 
			
		||||
 | 
			
		||||
    # load original state dict
 | 
			
		||||
    if model_name == "deta-swin-large":
 | 
			
		||||
        checkpoint_path = hf_hub_download(repo_id="nielsr/deta-checkpoints", filename="adet_swin_ft.pth")
 | 
			
		||||
    elif model_name == "deta-swin-large-o365":
 | 
			
		||||
        checkpoint_path = hf_hub_download(repo_id="jozhang97/deta-swin-l-o365", filename="deta_swin_pt_o365.pth")
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Model name {model_name} not supported")
 | 
			
		||||
 | 
			
		||||
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
 | 
			
		||||
 | 
			
		||||
    # original state dict
 | 
			
		||||
    for name, param in state_dict.items():
 | 
			
		||||
        print(name, param.shape)
 | 
			
		||||
 | 
			
		||||
    # rename keys
 | 
			
		||||
    rename_keys = create_rename_keys(config)
 | 
			
		||||
    for src, dest in rename_keys:
 | 
			
		||||
        rename_key(state_dict, src, dest)
 | 
			
		||||
    read_in_swin_q_k_v(state_dict, config.backbone_config)
 | 
			
		||||
    read_in_decoder_q_k_v(state_dict, config)
 | 
			
		||||
 | 
			
		||||
    # fix some prefixes
 | 
			
		||||
    for key in state_dict.copy().keys():
 | 
			
		||||
        if "transformer.decoder.class_embed" in key or "transformer.decoder.bbox_embed" in key:
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            state_dict[key.replace("transformer.decoder", "model.decoder")] = val
 | 
			
		||||
        if "input_proj" in key:
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            state_dict["model." + key] = val
 | 
			
		||||
        if "level_embed" in key or "pos_trans" in key or "pix_trans" in key or "enc_output" in key:
 | 
			
		||||
            val = state_dict.pop(key)
 | 
			
		||||
            state_dict[key.replace("transformer", "model")] = val
 | 
			
		||||
 | 
			
		||||
    # finally, create HuggingFace model and load state dict
 | 
			
		||||
    model = DetaForObjectDetection(config)
 | 
			
		||||
    model.load_state_dict(state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    device = "cuda" if torch.cuda.is_available() else "cpu"
 | 
			
		||||
    model.to(device)
 | 
			
		||||
 | 
			
		||||
    # load image processor
 | 
			
		||||
    processor = DetaImageProcessor(format="coco_detection")
 | 
			
		||||
 | 
			
		||||
    # verify our conversion on image
 | 
			
		||||
    img = prepare_img()
 | 
			
		||||
    encoding = processor(images=img, return_tensors="pt")
 | 
			
		||||
    pixel_values = encoding["pixel_values"]
 | 
			
		||||
    outputs = model(pixel_values.to(device))
 | 
			
		||||
 | 
			
		||||
    # verify logits
 | 
			
		||||
    print("Logits:", outputs.logits[0, :3, :3])
 | 
			
		||||
    print("Boxes:", outputs.pred_boxes[0, :3, :3])
 | 
			
		||||
    if model_name == "deta-swin-large":
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [[-7.6308, -2.8485, -5.3737], [-7.2037, -4.5505, -4.8027], [-7.2943, -4.2611, -4.6617]]
 | 
			
		||||
        )
 | 
			
		||||
        expected_boxes = torch.tensor([[0.4987, 0.4969, 0.9999], [0.2549, 0.5498, 0.4805], [0.5498, 0.2757, 0.0569]])
 | 
			
		||||
    elif model_name == "deta-swin-large-o365":
 | 
			
		||||
        expected_logits = torch.tensor(
 | 
			
		||||
            [[-8.0122, -3.5720, -4.9717], [-8.1547, -3.6886, -4.6389], [-7.6610, -3.6194, -5.0134]]
 | 
			
		||||
        )
 | 
			
		||||
        expected_boxes = torch.tensor([[0.2523, 0.5549, 0.4881], [0.7715, 0.4149, 0.4601], [0.5503, 0.2753, 0.0575]])
 | 
			
		||||
    assert torch.allclose(outputs.logits[0, :3, :3], expected_logits.to(device), atol=1e-4)
 | 
			
		||||
    assert torch.allclose(outputs.pred_boxes[0, :3, :3], expected_boxes.to(device), atol=1e-4)
 | 
			
		||||
    print("Everything ok!")
 | 
			
		||||
 | 
			
		||||
    if pytorch_dump_folder_path:
 | 
			
		||||
        # Save model and processor
 | 
			
		||||
        logger.info(f"Saving PyTorch model and processor to {pytorch_dump_folder_path}...")
 | 
			
		||||
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
 | 
			
		||||
        model.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
        processor.save_pretrained(pytorch_dump_folder_path)
 | 
			
		||||
 | 
			
		||||
    # Push to hub
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        print("Pushing model and processor to hub...")
 | 
			
		||||
        model.push_to_hub(f"jozhang97/{model_name}")
 | 
			
		||||
        processor.push_to_hub(f"jozhang97/{model_name}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--model_name",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="deta-swin-large",
 | 
			
		||||
        choices=["deta-swin-large", "deta-swin-large-o365"],
 | 
			
		||||
        help="Name of the model you'd like to convert.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_folder_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Path to the folder to output PyTorch model.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
 | 
			
		||||
    )
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_deta_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
 | 
			
		||||
@ -1,252 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2022 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""Convert EfficientFormer checkpoints from the original repository.
 | 
			
		||||
 | 
			
		||||
URL: https://github.com/snap-research/EfficientFormer
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import re
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import requests
 | 
			
		||||
import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    EfficientFormerConfig,
 | 
			
		||||
    EfficientFormerForImageClassificationWithTeacher,
 | 
			
		||||
    EfficientFormerImageProcessor,
 | 
			
		||||
)
 | 
			
		||||
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rename_key(old_name, num_meta4D_last_stage):
 | 
			
		||||
    new_name = old_name
 | 
			
		||||
 | 
			
		||||
    if "patch_embed" in old_name:
 | 
			
		||||
        _, layer, param = old_name.split(".")
 | 
			
		||||
 | 
			
		||||
        if layer == "0":
 | 
			
		||||
            new_name = old_name.replace("0", "convolution1")
 | 
			
		||||
        elif layer == "1":
 | 
			
		||||
            new_name = old_name.replace("1", "batchnorm_before")
 | 
			
		||||
        elif layer == "3":
 | 
			
		||||
            new_name = old_name.replace("3", "convolution2")
 | 
			
		||||
        else:
 | 
			
		||||
            new_name = old_name.replace("4", "batchnorm_after")
 | 
			
		||||
 | 
			
		||||
    if "network" in old_name and re.search(r"\d\.\d", old_name):
 | 
			
		||||
        two_digit_num = r"\b\d{2}\b"
 | 
			
		||||
        if bool(re.search(two_digit_num, old_name)):
 | 
			
		||||
            match = re.search(r"\d\.\d\d.", old_name).group()
 | 
			
		||||
        else:
 | 
			
		||||
            match = re.search(r"\d\.\d.", old_name).group()
 | 
			
		||||
        if int(match[0]) < 6:
 | 
			
		||||
            trimmed_name = old_name.replace(match, "")
 | 
			
		||||
            trimmed_name = trimmed_name.replace("network", match[0] + ".meta4D_layers.blocks." + match[2:-1])
 | 
			
		||||
            new_name = "intermediate_stages." + trimmed_name
 | 
			
		||||
        else:
 | 
			
		||||
            trimmed_name = old_name.replace(match, "")
 | 
			
		||||
            if int(match[2]) < num_meta4D_last_stage:
 | 
			
		||||
                trimmed_name = trimmed_name.replace("network", "meta4D_layers.blocks." + match[2])
 | 
			
		||||
            else:
 | 
			
		||||
                layer_index = str(int(match[2]) - num_meta4D_last_stage)
 | 
			
		||||
                trimmed_name = trimmed_name.replace("network", "meta3D_layers.blocks." + layer_index)
 | 
			
		||||
                if "norm1" in old_name:
 | 
			
		||||
                    trimmed_name = trimmed_name.replace("norm1", "layernorm1")
 | 
			
		||||
                elif "norm2" in old_name:
 | 
			
		||||
                    trimmed_name = trimmed_name.replace("norm2", "layernorm2")
 | 
			
		||||
                elif "fc1" in old_name:
 | 
			
		||||
                    trimmed_name = trimmed_name.replace("fc1", "linear_in")
 | 
			
		||||
                elif "fc2" in old_name:
 | 
			
		||||
                    trimmed_name = trimmed_name.replace("fc2", "linear_out")
 | 
			
		||||
 | 
			
		||||
            new_name = "last_stage." + trimmed_name
 | 
			
		||||
 | 
			
		||||
    elif "network" in old_name and re.search(r".\d.", old_name):
 | 
			
		||||
        new_name = old_name.replace("network", "intermediate_stages")
 | 
			
		||||
 | 
			
		||||
    if "fc" in new_name:
 | 
			
		||||
        new_name = new_name.replace("fc", "convolution")
 | 
			
		||||
    elif ("norm1" in new_name) and ("layernorm1" not in new_name):
 | 
			
		||||
        new_name = new_name.replace("norm1", "batchnorm_before")
 | 
			
		||||
    elif ("norm2" in new_name) and ("layernorm2" not in new_name):
 | 
			
		||||
        new_name = new_name.replace("norm2", "batchnorm_after")
 | 
			
		||||
    if "proj" in new_name:
 | 
			
		||||
        new_name = new_name.replace("proj", "projection")
 | 
			
		||||
    if "dist_head" in new_name:
 | 
			
		||||
        new_name = new_name.replace("dist_head", "distillation_classifier")
 | 
			
		||||
    elif "head" in new_name:
 | 
			
		||||
        new_name = new_name.replace("head", "classifier")
 | 
			
		||||
    elif "patch_embed" in new_name:
 | 
			
		||||
        new_name = "efficientformer." + new_name
 | 
			
		||||
    elif new_name == "norm.weight" or new_name == "norm.bias":
 | 
			
		||||
        new_name = new_name.replace("norm", "layernorm")
 | 
			
		||||
        new_name = "efficientformer." + new_name
 | 
			
		||||
    else:
 | 
			
		||||
        new_name = "efficientformer.encoder." + new_name
 | 
			
		||||
 | 
			
		||||
    return new_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_torch_checkpoint(checkpoint, num_meta4D_last_stage):
 | 
			
		||||
    for key in checkpoint.copy().keys():
 | 
			
		||||
        val = checkpoint.pop(key)
 | 
			
		||||
        checkpoint[rename_key(key, num_meta4D_last_stage)] = val
 | 
			
		||||
 | 
			
		||||
    return checkpoint
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# We will verify our results on a COCO image
 | 
			
		||||
def prepare_img():
 | 
			
		||||
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 | 
			
		||||
    image = Image.open(requests.get(url, stream=True).raw)
 | 
			
		||||
 | 
			
		||||
    return image
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_efficientformer_checkpoint(
 | 
			
		||||
    checkpoint_path: Path, efficientformer_config_file: Path, pytorch_dump_path: Path, push_to_hub: bool
 | 
			
		||||
):
 | 
			
		||||
    orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
 | 
			
		||||
    config = EfficientFormerConfig.from_json_file(efficientformer_config_file)
 | 
			
		||||
    model = EfficientFormerForImageClassificationWithTeacher(config)
 | 
			
		||||
    model_name = "_".join(checkpoint_path.split("/")[-1].split(".")[0].split("_")[:-1])
 | 
			
		||||
 | 
			
		||||
    num_meta4D_last_stage = config.depths[-1] - config.num_meta3d_blocks + 1
 | 
			
		||||
    new_state_dict = convert_torch_checkpoint(orig_state_dict, num_meta4D_last_stage)
 | 
			
		||||
 | 
			
		||||
    model.load_state_dict(new_state_dict)
 | 
			
		||||
    model.eval()
 | 
			
		||||
 | 
			
		||||
    pillow_resamplings = {
 | 
			
		||||
        "bilinear": PILImageResampling.BILINEAR,
 | 
			
		||||
        "bicubic": PILImageResampling.BICUBIC,
 | 
			
		||||
        "nearest": PILImageResampling.NEAREST,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # prepare image
 | 
			
		||||
    image = prepare_img()
 | 
			
		||||
    image_size = 256
 | 
			
		||||
    crop_size = 224
 | 
			
		||||
    processor = EfficientFormerImageProcessor(
 | 
			
		||||
        size={"shortest_edge": image_size},
 | 
			
		||||
        crop_size={"height": crop_size, "width": crop_size},
 | 
			
		||||
        resample=pillow_resamplings["bicubic"],
 | 
			
		||||
    )
 | 
			
		||||
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
 | 
			
		||||
 | 
			
		||||
    # original processing pipeline
 | 
			
		||||
    image_transforms = Compose(
 | 
			
		||||
        [
 | 
			
		||||
            Resize(image_size, interpolation=pillow_resamplings["bicubic"]),
 | 
			
		||||
            CenterCrop(crop_size),
 | 
			
		||||
            ToTensor(),
 | 
			
		||||
            Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
    original_pixel_values = image_transforms(image).unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(original_pixel_values, pixel_values)
 | 
			
		||||
 | 
			
		||||
    outputs = model(pixel_values)
 | 
			
		||||
    logits = outputs.logits
 | 
			
		||||
 | 
			
		||||
    expected_shape = (1, 1000)
 | 
			
		||||
 | 
			
		||||
    if "l1" in model_name:
 | 
			
		||||
        expected_logits = torch.Tensor(
 | 
			
		||||
            [-0.1312, 0.4353, -1.0499, -0.5124, 0.4183, -0.6793, -1.3777, -0.0893, -0.7358, -2.4328]
 | 
			
		||||
        )
 | 
			
		||||
        assert torch.allclose(logits[0, :10], expected_logits, atol=1e-3)
 | 
			
		||||
        assert logits.shape == expected_shape
 | 
			
		||||
    elif "l3" in model_name:
 | 
			
		||||
        expected_logits = torch.Tensor(
 | 
			
		||||
            [-1.3150, -1.5456, -1.2556, -0.8496, -0.7127, -0.7897, -0.9728, -0.3052, 0.3751, -0.3127]
 | 
			
		||||
        )
 | 
			
		||||
        assert torch.allclose(logits[0, :10], expected_logits, atol=1e-3)
 | 
			
		||||
        assert logits.shape == expected_shape
 | 
			
		||||
    elif "l7" in model_name:
 | 
			
		||||
        expected_logits = torch.Tensor(
 | 
			
		||||
            [-1.0283, -1.4131, -0.5644, -1.3115, -0.5785, -1.2049, -0.7528, 0.1992, -0.3822, -0.0878]
 | 
			
		||||
        )
 | 
			
		||||
        assert logits.shape == expected_shape
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Unknown model checkpoint: {checkpoint_path}. Supported version of efficientformer are l1, l3 and l7"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Save Checkpoints
 | 
			
		||||
    Path(pytorch_dump_path).mkdir(exist_ok=True)
 | 
			
		||||
    model.save_pretrained(pytorch_dump_path)
 | 
			
		||||
    print(f"Checkpoint successfuly converted. Model saved at {pytorch_dump_path}")
 | 
			
		||||
    processor.save_pretrained(pytorch_dump_path)
 | 
			
		||||
    print(f"Processor successfuly saved at {pytorch_dump_path}")
 | 
			
		||||
 | 
			
		||||
    if push_to_hub:
 | 
			
		||||
        print("Pushing model to the hub...")
 | 
			
		||||
 | 
			
		||||
        model.push_to_hub(
 | 
			
		||||
            repo_id=f"Bearnardd/{pytorch_dump_path}",
 | 
			
		||||
            commit_message="Add model",
 | 
			
		||||
            use_temp_dir=True,
 | 
			
		||||
        )
 | 
			
		||||
        processor.push_to_hub(
 | 
			
		||||
            repo_id=f"Bearnardd/{pytorch_dump_path}",
 | 
			
		||||
            commit_message="Add image processor",
 | 
			
		||||
            use_temp_dir=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser()
 | 
			
		||||
    # Required parameters
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_model_path",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="Path to EfficientFormer pytorch checkpoint.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--config_file",
 | 
			
		||||
        default=None,
 | 
			
		||||
        type=str,
 | 
			
		||||
        required=True,
 | 
			
		||||
        help="The json file for EfficientFormer model config.",
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--no-push_to_hub",
 | 
			
		||||
        dest="push_to_hub",
 | 
			
		||||
        action="store_false",
 | 
			
		||||
        help="Do not push model and image processor to the hub",
 | 
			
		||||
    )
 | 
			
		||||
    parser.set_defaults(push_to_hub=True)
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_efficientformer_checkpoint(
 | 
			
		||||
        checkpoint_path=args.pytorch_model_path,
 | 
			
		||||
        efficientformer_config_file=args.config_file,
 | 
			
		||||
        pytorch_dump_path=args.pytorch_dump_path,
 | 
			
		||||
        push_to_hub=args.push_to_hub,
 | 
			
		||||
    )
 | 
			
		||||
@ -1,181 +0,0 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2023 The HuggingFace Inc. team.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model."""
 | 
			
		||||
 | 
			
		||||
import argparse
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_tf_gptsan_to_pt(args):
 | 
			
		||||
    parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
 | 
			
		||||
    params = json.loads(open(parameter_file).read())
 | 
			
		||||
    if not params:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
 | 
			
		||||
        )
 | 
			
		||||
    if not args.output.endswith(".pt"):
 | 
			
		||||
        args.output = args.output + ".pt"
 | 
			
		||||
    new_state = OrderedDict()
 | 
			
		||||
    with tf.device("/CPU:0"):
 | 
			
		||||
        reader = tf.train.load_checkpoint(args.tf_model_dir)
 | 
			
		||||
        shapes = reader.get_variable_to_shape_map()
 | 
			
		||||
        for key_name in shapes.keys():
 | 
			
		||||
            vnp = reader.get_tensor(key_name).astype(np.float16)
 | 
			
		||||
            if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"):
 | 
			
		||||
                continue
 | 
			
		||||
            if key_name.startswith("pasts/"):
 | 
			
		||||
                if key_name.startswith("pasts/mlp"):
 | 
			
		||||
                    player = int(key_name[9])
 | 
			
		||||
                elif key_name.startswith("pasts/out"):
 | 
			
		||||
                    player = 8
 | 
			
		||||
                name = "model.sqout.%d.weight" % (player * 2)  # enter to nn.Sequencial with Tanh, so 2 at a time
 | 
			
		||||
                state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/moe"):
 | 
			
		||||
                player = int(key_name[9:].split("/")[0])
 | 
			
		||||
                if key_name.endswith("/switch_gating/kernel"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player
 | 
			
		||||
                    state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/softmlp/kernel"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player
 | 
			
		||||
                    state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"):
 | 
			
		||||
                    nlayer = key_name[-9:-7]
 | 
			
		||||
                    for i in range(16):
 | 
			
		||||
                        name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer)
 | 
			
		||||
                        state = (
 | 
			
		||||
                            vnp[i].transpose([1, 0]).copy()
 | 
			
		||||
                        )  # In Mesh-Tensorflow, it is one array, so it is divided
 | 
			
		||||
                        new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/mlp"):
 | 
			
		||||
                player = int(key_name[9:].split("/")[0])
 | 
			
		||||
                if key_name.endswith("/p1/kernel"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player
 | 
			
		||||
                    state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/p1/bias"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/p2/kernel"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player
 | 
			
		||||
                    state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/p2/bias"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/ln"):
 | 
			
		||||
                player = int(key_name[8:].split("/")[0])
 | 
			
		||||
                if key_name.endswith("/b"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.norm.bias" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/g"):
 | 
			
		||||
                    name = "model.blocks.%d.feed_forward.norm.weight" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/att"):
 | 
			
		||||
                player = int(key_name[9:].split("/")[0])
 | 
			
		||||
                if key_name.endswith("/qkv/kernel"):
 | 
			
		||||
                    state = vnp.copy()  # Compute same dimension as Mesh-tensorflow using einsum
 | 
			
		||||
                    state_q = state[:, 0, :, :]
 | 
			
		||||
                    state_k = state[:, 1, :, :]
 | 
			
		||||
                    state_v = state[:, 2, :, :]
 | 
			
		||||
                    state_q = (
 | 
			
		||||
                        state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]])
 | 
			
		||||
                        .transpose([1, 0])
 | 
			
		||||
                        .copy()
 | 
			
		||||
                    )  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    state_k = (
 | 
			
		||||
                        state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]])
 | 
			
		||||
                        .transpose([1, 0])
 | 
			
		||||
                        .copy()
 | 
			
		||||
                    )  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    state_v = (
 | 
			
		||||
                        state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]])
 | 
			
		||||
                        .transpose([1, 0])
 | 
			
		||||
                        .copy()
 | 
			
		||||
                    )  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player
 | 
			
		||||
                    new_state[name] = torch.tensor(state_q)
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player
 | 
			
		||||
                    new_state[name] = torch.tensor(state_k)
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player
 | 
			
		||||
                    new_state[name] = torch.tensor(state_v)
 | 
			
		||||
                elif key_name.endswith("/o/kernel"):
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player
 | 
			
		||||
                    state = (
 | 
			
		||||
                        vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy()
 | 
			
		||||
                    )  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/an"):
 | 
			
		||||
                player = int(key_name[8:].split("/")[0])
 | 
			
		||||
                if key_name.endswith("/b"):
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.norm.bias" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
                elif key_name.endswith("/g"):
 | 
			
		||||
                    name = "model.blocks.%d.self_attn.norm.weight" % player
 | 
			
		||||
                    state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif (
 | 
			
		||||
                key_name.startswith("model/wte")
 | 
			
		||||
                or key_name.startswith("model/wpe")
 | 
			
		||||
                or key_name.startswith("model/ete")
 | 
			
		||||
            ):
 | 
			
		||||
                nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[
 | 
			
		||||
                    key_name[-3:]
 | 
			
		||||
                ]
 | 
			
		||||
                name = "model.%s.weight" % nlayer
 | 
			
		||||
                state = vnp.copy()  # same in embedded
 | 
			
		||||
                new_state[name] = torch.tensor(state)
 | 
			
		||||
                if key_name.startswith("model/wte"):
 | 
			
		||||
                    name = "lm_head.weight"
 | 
			
		||||
                    state = vnp.copy()  # same in embedded
 | 
			
		||||
                    new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name.startswith("model/wob"):
 | 
			
		||||
                name = "final_logits_bias"
 | 
			
		||||
                state = vnp.copy()  # same in embedded
 | 
			
		||||
                state = state.reshape((1, -1))
 | 
			
		||||
                new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name == "model/dense/kernel":
 | 
			
		||||
                name = "model.last_project.weight"
 | 
			
		||||
                state = vnp.transpose([1, 0]).copy()  # Mesh-Tensorflow is a diagonal matrix
 | 
			
		||||
                new_state[name] = torch.tensor(state)
 | 
			
		||||
            elif key_name == "model/dense_1/bias":
 | 
			
		||||
                name = "model.last_project.bias"
 | 
			
		||||
                state = vnp.copy()  # same because it is one dimensional
 | 
			
		||||
                new_state[name] = torch.tensor(state)
 | 
			
		||||
    torch.save(new_state, args.output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
        description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
 | 
			
		||||
    )
 | 
			
		||||
    parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model")
 | 
			
		||||
    parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    convert_tf_gptsan_to_pt(args)
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user