mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-15 23:25:03 +08:00
Compare commits
5 Commits
v4.57.1-te
...
add-missin
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e236b275d | |||
| afa5b2b9d5 | |||
| 6f6095e0cf | |||
| c4cfc2e023 | |||
| 5c6d6bed4d |
@ -46,8 +46,8 @@ jobs:
|
||||
- run: uv pip install -U -e .
|
||||
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
|
||||
- run: mkdir -p test_preparation
|
||||
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
|
||||
- run: python utils/tests_fetcher.py --filter_tests
|
||||
- run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt || true
|
||||
- run: python utils/tests_fetcher.py --filter_tests || true
|
||||
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
|
||||
- run: |
|
||||
if [ ! -s test_preparation/generated_config.yml ]; then
|
||||
@ -98,8 +98,8 @@ jobs:
|
||||
- run: uv pip install -U -e .
|
||||
- run: echo 'export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)"' >> "$BASH_ENV" && source "$BASH_ENV"
|
||||
- run: mkdir -p test_preparation
|
||||
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt
|
||||
- run: python utils/tests_fetcher.py --filter_tests
|
||||
- run: python utils/tests_fetcher.py --fetch_all | tee tests_fetched_summary.txt || true
|
||||
- run: python utils/tests_fetcher.py --filter_tests || true
|
||||
- run: export "GIT_COMMIT_MESSAGE=$(git show -s --format=%s)" && echo $GIT_COMMIT_MESSAGE && python .circleci/create_circleci_config.py --fetcher_folder test_preparation
|
||||
- run: |
|
||||
if [ ! -s test_preparation/generated_config.yml ]; then
|
||||
|
||||
1
Makefile
1
Makefile
@ -45,6 +45,7 @@ repo-consistency:
|
||||
python utils/check_modular_conversion.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
python utils/check_init_weights_data.py
|
||||
python utils/check_inits.py
|
||||
python utils/check_pipeline_typing.py
|
||||
python utils/check_config_docstrings.py
|
||||
|
||||
@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in
|
||||
@ -533,9 +533,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf
|
||||
|
||||
@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
|
||||
@ -339,9 +339,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
### Convert checkpoints to Transformers
|
||||
|
||||
@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
|
||||
|
||||
@ -170,7 +170,7 @@ Per quanto riguarda la classe `TrainingArguments`:
|
||||
- L'argomento `evaluate_during_training` di `TrainingArguments` è deprecato a favore di `eval_strategy`.
|
||||
|
||||
Per quanto riguarda il modello Transfo-XL:
|
||||
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_words_embeddings`.
|
||||
- L'attributo di configurazione `tie_weight` di Transfo-XL diventa `tie_word_embeddings`.
|
||||
- Il metodo di modellazione `reset_length` di Transfo-XL diventa `reset_memory_length`.
|
||||
|
||||
Per quanto riguarda le pipeline:
|
||||
|
||||
@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、
|
||||
@ -431,9 +431,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
`_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。
|
||||
|
||||
@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
```
|
||||
|
||||
몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다:
|
||||
@ -371,9 +371,9 @@ def _init_weights(self, module):
|
||||
module.project_hid._is_hf_initialized = True
|
||||
module.project_q._is_hf_initialized = True
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
```
|
||||
|
||||
`_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q` 및 `module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다.
|
||||
|
||||
@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.
|
||||
|
||||
@ -502,16 +502,10 @@ class DummyBertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, DummyBertLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
||||
if "RMSNorm" in module.__class__.__name__:
|
||||
module.weight.data.zero_()
|
||||
module.weight.zero_()
|
||||
|
||||
|
||||
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
|
||||
|
||||
@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
def token_type_ids_mask_function(
|
||||
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
@ -440,7 +440,15 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -505,16 +505,10 @@ class RobertaLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, RobertaLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
|
||||
nn.init.xavier_uniform_(module.output_proj.weight.data)
|
||||
nn.init.constant_(module.output_proj.bias.data, 0.0)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
if hasattr(module, "reference_points") and not self.config.two_stage:
|
||||
|
||||
@ -19,7 +19,15 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
prefix = "model.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in self.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(self._tied_weights_keys, dict):
|
||||
self._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
self._tied_weights_keys = prefixed_mapping
|
||||
|
||||
self.post_init()
|
||||
|
||||
|
||||
@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
serializable_config_dict["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(serializable_config_dict)
|
||||
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
|
||||
if hasattr(self, "quantization_config"):
|
||||
output["quantization_config"] = (
|
||||
self.quantization_config.to_dict()
|
||||
if not isinstance(self.quantization_config, dict)
|
||||
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
|
||||
else self.quantization_config
|
||||
)
|
||||
self.dict_dtype_to_str(output)
|
||||
|
||||
136
src/transformers/conversion_mapping.py
Normal file
136
src/transformers/conversion_mapping.py
Normal file
@ -0,0 +1,136 @@
|
||||
# coding=utf-8
|
||||
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def _build_checkpoint_conversion_mapping():
|
||||
mapping = {
|
||||
"mixtral": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w1.weight",
|
||||
"block_sparse_moe.experts.*.w3.weight",
|
||||
], # you give me a list of 2 keys, I collect a list of a list of tensors
|
||||
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"block_sparse_moe.experts.*.w2.weight",
|
||||
],
|
||||
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
|
||||
operations=[
|
||||
MergeModulelist(
|
||||
dim=0
|
||||
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
|
||||
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
|
||||
),
|
||||
# WeightConverter(
|
||||
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
|
||||
# "self_attn.qkv_proj",
|
||||
# operations=[Concatenate(dim=0)], # more like stack?
|
||||
# ),
|
||||
WeightConverter("*.block_sparse_moe.", "*.mlp."),
|
||||
],
|
||||
"qwen2_moe": [
|
||||
WeightConverter(
|
||||
source_keys=[
|
||||
"mlp.experts.*.gate_proj.weight",
|
||||
"mlp.experts.*.up_proj.weight",
|
||||
],
|
||||
target_keys="mlp.experts.gate_up_proj",
|
||||
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys=["mlp.experts.*.down_proj.weight"],
|
||||
target_keys="mlp.experts.down_proj",
|
||||
operations=[MergeModulelist(dim=0)],
|
||||
),
|
||||
],
|
||||
"legacy": [
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.gamma",
|
||||
target_keys="LayerNorm.weight",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="LayerNorm.beta",
|
||||
target_keys="LayerNorm.bias",
|
||||
),
|
||||
],
|
||||
}
|
||||
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="weight_g",
|
||||
target_keys="parametrizations.weight.original0",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="weight_v",
|
||||
target_keys="parametrizations.weight.original1",
|
||||
),
|
||||
]
|
||||
else:
|
||||
mapping["legacy"] += [
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original0",
|
||||
target_keys="weight_g",
|
||||
),
|
||||
WeightConverter(
|
||||
source_keys="parametrizations.weight.original1",
|
||||
target_keys="weight_v",
|
||||
),
|
||||
]
|
||||
|
||||
mapping["phimoe"] = mapping["mixtral"].copy()
|
||||
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
|
||||
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
|
||||
mapping["dot1"] = mapping["qwen2_moe"].copy()
|
||||
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["jamba"] = mapping["qwen2_moe"].copy()
|
||||
mapping["lfm2_moe"] = mapping["mixtral"].copy()
|
||||
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
|
||||
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
|
||||
mapping["minimax"] = mapping["mixtral"].copy()
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
_checkpoint_conversion_mapping_cache = None
|
||||
|
||||
|
||||
def get_checkpoint_conversion_mapping(model_type):
|
||||
global _checkpoint_conversion_mapping_cache
|
||||
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
|
||||
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
|
||||
return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))
|
||||
732
src/transformers/core_model_loading.py
Normal file
732
src/transformers/core_model_loading.py
Normal file
@ -0,0 +1,732 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Core helpers for loading model checkpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import MutableMapping, MutableSet, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, DTensor, Replicate, TensorParallelLayer
|
||||
from .utils import is_torch_greater_or_equal, logging
|
||||
|
||||
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
|
||||
if _is_dtensor_available:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .quantizers import HfQuantizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
str_to_torch_dtype = {
|
||||
"BOOL": torch.bool,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"I16": torch.int16,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I32": torch.int32,
|
||||
"F32": torch.float32,
|
||||
"F64": torch.float64,
|
||||
"I64": torch.int64,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
|
||||
"""
|
||||
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
|
||||
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
|
||||
"""
|
||||
star = r"(\d+)" if digits_only else r"(.+)"
|
||||
return glob.replace(r"\*", star)
|
||||
|
||||
|
||||
def build_glob_alt(
|
||||
globs: list[str],
|
||||
) -> tuple[re.Pattern, dict[str, str]]:
|
||||
r"""
|
||||
Build one compiled regex alternation with a named group per glob. This allows to run a single
|
||||
re.match and get the correct group name to finally get which pattern matched.
|
||||
Returns (compiled_regex, name->glob map).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
|
||||
>>> print(reg)
|
||||
(re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE),
|
||||
>>> print(map_)
|
||||
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'})
|
||||
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
|
||||
>>> print(match_.lastgroup)
|
||||
'g0'
|
||||
>>> print(map_[match_.lastgroup])
|
||||
mlp.*.w1
|
||||
```
|
||||
"""
|
||||
name_map: dict[str, str] = {}
|
||||
parts: list[str] = []
|
||||
|
||||
for i, g in enumerate(globs):
|
||||
name = f"g{i}"
|
||||
name_map[name] = g
|
||||
pat_src = _glob_to_regex_src(g)
|
||||
prefix_src = ""
|
||||
if pat_src.startswith("*"):
|
||||
prefix_src = "."
|
||||
elif not pat_src.startswith(r"\^") and not pat_src.startswith(r".*"):
|
||||
prefix_src = ".*"
|
||||
|
||||
parts.append(f"(?P<{name}>{prefix_src}{pat_src}.*)")
|
||||
|
||||
alt_src = "|".join(parts).replace("\\^", "^").replace("\\.", r"\.")
|
||||
try:
|
||||
reg = re.compile(alt_src)
|
||||
except re.error as e:
|
||||
logger.error(f"Error compiling regex for alternation: {alt_src}")
|
||||
raise e
|
||||
|
||||
return reg, name_map
|
||||
|
||||
|
||||
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Match the key against the alternation; return the original glob string that matched.
|
||||
"""
|
||||
m = alt.match(key)
|
||||
if not m:
|
||||
return None
|
||||
return name_map.get(m.lastgroup)
|
||||
|
||||
|
||||
class ConversionOps:
|
||||
"""Base class for weight conversion operations."""
|
||||
|
||||
# The inverse operation class, will be used when saving the checkpoint
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
@abstractmethod
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Chunk(ConversionOps):
|
||||
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
|
||||
if chunks is None and sizes is None:
|
||||
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
|
||||
if chunks is not None and chunks <= 0:
|
||||
raise ValueError("`chunks` must be a strictly positive integer.")
|
||||
self.dim = dim
|
||||
self.chunks = chunks
|
||||
self.sizes = list(sizes) if sizes is not None else None
|
||||
self.reverse_op = Concatenate
|
||||
|
||||
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
|
||||
# chunk requires a single tensor input
|
||||
if len(value) != 1 or len(value[0]) != 1:
|
||||
raise ValueError("Chunk operation requires a single tensor input.")
|
||||
return list(torch.chunk(value[0][0], self.chunks, dim=self.dim))
|
||||
|
||||
|
||||
class Concatenate(ConversionOps):
|
||||
"""Concatenate tensors along `dim` using a reusable buffer."""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
self.dim = dim
|
||||
self.reverse_op = Chunk
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
|
||||
if isinstance(value[0], list):
|
||||
value = [v[0] for v in value]
|
||||
tensors = value
|
||||
if not tensors:
|
||||
raise ValueError("Fuse requires at least one tensor to concatenate.")
|
||||
|
||||
return torch.cat(tuple(tensors), dim=self.dim)
|
||||
|
||||
|
||||
class MergeModulelist(Concatenate):
|
||||
"""
|
||||
Merge a list of tensors into a single tensor along the first dimension.
|
||||
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 0):
|
||||
super().__init__(dim=dim)
|
||||
self.reverse_op = SplitModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
|
||||
merged = []
|
||||
for group in value:
|
||||
if not isinstance(group, Sequence) or len(group) == 0:
|
||||
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
|
||||
group = [k for k in group if k.ndim]
|
||||
merged.append(torch.stack(group, dim=self.dim))
|
||||
return merged
|
||||
|
||||
|
||||
class SplitModulelist(ConversionOps):
|
||||
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
|
||||
|
||||
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
|
||||
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
|
||||
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
|
||||
self.sizes = [list(sub) for sub in sizes]
|
||||
self.dim = dim
|
||||
self.reverse_op = MergeModulelist
|
||||
|
||||
@torch.no_grad
|
||||
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
|
||||
if not isinstance(value, Sequence):
|
||||
raise TypeError("SplitModulelist expects a sequence of tensors.")
|
||||
if len(value) != len(self.sizes):
|
||||
raise ValueError("Number of tensors does not match the provided split specifications.")
|
||||
|
||||
result: list[list[torch.Tensor]] = []
|
||||
for tensor, split_sizes in zip(value, self.sizes):
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
|
||||
splits = torch.split(tensor, split_sizes, dim=self.dim)
|
||||
result.append(list(splits))
|
||||
return result
|
||||
|
||||
|
||||
class PermuteForRope(ConversionOps):
|
||||
"""
|
||||
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
dim1, dim2 = tensor.shape
|
||||
n_heads = self.config.getattr("num_attention_heads", 1)
|
||||
|
||||
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
|
||||
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
|
||||
return tensor
|
||||
|
||||
@torch.no_grad
|
||||
def convert(
|
||||
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
|
||||
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
|
||||
self.config = config
|
||||
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
|
||||
return out
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WeightConverter:
|
||||
r"""
|
||||
A weight convert that acts on a pattern of source keys.
|
||||
The keys need to be collected based on the target keys.
|
||||
|
||||
With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match:
|
||||
`model.layers.*.experts.*` -> it will act on all of them
|
||||
{"model.layers.*.experts.*": []}
|
||||
but
|
||||
`experts.*.mlp` will be layer specific.
|
||||
{"model.layers.1.experts.*": [], }
|
||||
- source_keys: str | list[str] (wildcards '*' match digits)
|
||||
- target_keys: str | list[str] | None
|
||||
- distributed_operation / operations / quantization_operations are ALWAYS lists.
|
||||
|
||||
TODO: for BNB we need to collect model.weight.quant_state_keys
|
||||
"""
|
||||
|
||||
source_keys: Union[str, list[str]]
|
||||
target_keys: Optional[Union[str, list[str]]] = None
|
||||
operations: list[ConversionOps] = field(default_factory=list, repr=False)
|
||||
|
||||
distributed_operation: Optional[TensorParallelLayer] = None
|
||||
quantization_operation: Optional[ConversionOps] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.source_keys, list):
|
||||
self.source_keys = [self.source_keys]
|
||||
targets_were_none = False
|
||||
if not isinstance(self.target_keys, list):
|
||||
if self.target_keys is None:
|
||||
self.target_keys = list(self.source_keys)
|
||||
targets_were_none = True
|
||||
else:
|
||||
self.target_keys = [self.target_keys]
|
||||
|
||||
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
|
||||
raise ValueError(
|
||||
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConversionEntry:
|
||||
weight_converter: WeightConverter
|
||||
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
|
||||
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
|
||||
|
||||
|
||||
# Factory function to create LoadedParameter subclasses dynamically
|
||||
def get_loaded_parameter_class(base_cls):
|
||||
"""
|
||||
base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
|
||||
Returns a new class that combines the base_cls with LoadedParameterMixin
|
||||
|
||||
"""
|
||||
|
||||
class LoadedParam(base_cls):
|
||||
_inplace_methods = [
|
||||
"add_",
|
||||
"mul_",
|
||||
"clamp_",
|
||||
"zero_",
|
||||
"fill_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
"copy_",
|
||||
"erfinv_",
|
||||
"log_",
|
||||
"__getitem__",
|
||||
"neg_",
|
||||
"exp_",
|
||||
"sub_",
|
||||
]
|
||||
|
||||
def __new__(cls, from_existing, **kwargs):
|
||||
if isinstance(from_existing, torch.nn.Parameter):
|
||||
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
|
||||
else:
|
||||
inst = super().__new__(cls, from_existing)
|
||||
# we store the original object to get it back later on
|
||||
inst._original = from_existing
|
||||
# Explicitly override all in-place methods per instance
|
||||
for method_name in inst._inplace_methods:
|
||||
setattr(inst, method_name, MethodType(inst._skip, inst))
|
||||
|
||||
return inst
|
||||
|
||||
def _skip(self, *args, **kwargs):
|
||||
"""Helper to skip in-place operations."""
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return f"LoadedParameter(data={self.data})"
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return super().data
|
||||
|
||||
@data.setter
|
||||
def data(self, new):
|
||||
pass
|
||||
|
||||
def __lt__(self, other):
|
||||
return torch.Tensor.__lt__(self, other)
|
||||
|
||||
def __le__(self, other):
|
||||
return torch.Tensor.__le__(self, other)
|
||||
|
||||
def __gt__(self, other):
|
||||
return torch.Tensor.__gt__(self, other)
|
||||
|
||||
def __ge__(self, other):
|
||||
return torch.Tensor.__ge__(self, other)
|
||||
|
||||
def __eq__(self, other):
|
||||
return torch.Tensor.__eq__(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return torch.Tensor.__ne__(self, other)
|
||||
|
||||
def __iadd__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __isub__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imul__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imatmul__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __itruediv__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ifloordiv__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __imod__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ipow__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __iand__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ior__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ixor__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __ilshift__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __irshift__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
return LoadedParam
|
||||
|
||||
|
||||
def _materialize_copy(tensor, dtype=None):
|
||||
tensor = tensor[...]
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype)
|
||||
return tensor
|
||||
|
||||
|
||||
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
|
||||
def _job():
|
||||
return _materialize_copy(tensor, dtype)
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
|
||||
def _job():
|
||||
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
|
||||
|
||||
return thread_pool.submit(_job)
|
||||
|
||||
|
||||
def dot_natural_key(s: str):
|
||||
parts = s.split(".")
|
||||
for i, p in enumerate(parts):
|
||||
# whole-segment digits -> int; otherwise leave as str
|
||||
if p.isdigit():
|
||||
parts[i] = int(p)
|
||||
return parts
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_to_misc(
|
||||
layer_name: str,
|
||||
misc: MutableMapping[str, str],
|
||||
extras: Any = None,
|
||||
op: Union[list[ConversionOps], ConversionOps, None] = None,
|
||||
):
|
||||
# A simple helper to handle errors with contextual messages.
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
|
||||
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
|
||||
if curr_op is None:
|
||||
return None
|
||||
if isinstance(curr_op, (list, tuple, set)):
|
||||
names = [o.__class__.__name__ for o in curr_op if o is not None]
|
||||
if not names:
|
||||
return None
|
||||
return ", ".join(names)
|
||||
return curr_op.__class__.__name__
|
||||
|
||||
op_name = _format_op_name(op)
|
||||
if isinstance(extras, tuple) and len(extras) == 2:
|
||||
values, target_keys = extras
|
||||
descriptor = f"{op_name} " if op_name else ""
|
||||
misc[layer_name] = (
|
||||
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
|
||||
)
|
||||
elif isinstance(extras, str):
|
||||
suffix = f" via {op_name}" if op_name else ""
|
||||
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
|
||||
elif extras is None and op_name:
|
||||
misc[layer_name] = f"{op_name}: {e}"
|
||||
else:
|
||||
misc[layer_name] = f"{extras} |Error: {e}"
|
||||
raise SkipLayer()
|
||||
|
||||
|
||||
def set_param_for_module(
|
||||
model: PreTrainedModel,
|
||||
layer_name: str,
|
||||
param_value: torch.Tensor,
|
||||
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
|
||||
missing_keys: MutableSet[str],
|
||||
misc: MutableMapping[str, Any],
|
||||
distributed_operation: Optional[TensorParallelLayer],
|
||||
):
|
||||
with log_to_misc(layer_name, misc, layer_name):
|
||||
module_path, _, param_name = layer_name.rpartition(".")
|
||||
module_obj = model.get_submodule(module_path) if module_path else model
|
||||
param_value = param_value[0] if isinstance(param_value, list) else param_value[...]
|
||||
ref = getattr(module_obj, param_name)
|
||||
|
||||
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
|
||||
if not isinstance(param_value, torch.nn.Parameter):
|
||||
if distributed_operation is not None:
|
||||
param_value = DTensor.from_local(
|
||||
param_value,
|
||||
distributed_operation.device_mesh,
|
||||
getattr(distributed_operation, "shard", Replicate()),
|
||||
run_check=False,
|
||||
shape=ref.size(),
|
||||
stride=ref.stride(),
|
||||
)
|
||||
if not use_dtensor:
|
||||
# we convert to local
|
||||
param_value = param_value.to_local()
|
||||
if param_name not in module_obj._buffers:
|
||||
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
|
||||
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)
|
||||
|
||||
# Remove from missing keys (it's either mismatched, or all good)
|
||||
missing_keys.discard(layer_name)
|
||||
if ref is not None and ref.shape != param_value.shape:
|
||||
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
|
||||
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
|
||||
else:
|
||||
param_value._is_hf_initialized = True # super important otherwise _init_weight re-initi if bias is missing
|
||||
setattr(module_obj, param_name, param_value)
|
||||
|
||||
|
||||
class SkipLayer(Exception):
|
||||
"""Control-flow sentinel: abort processing of the current layer only."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def convert_and_load_state_dict_in_model(
|
||||
model: PreTrainedModel,
|
||||
state_dict: dict[str, Any],
|
||||
weight_mapping: dict[str, WeightConverter] | None,
|
||||
tp_plan: dict[str, str] | None,
|
||||
quantizer: HfQuantizer | None,
|
||||
dtype: torch.dtype | None = None,
|
||||
device_map: dict | None = None,
|
||||
dtype_plan: dict | None = None,
|
||||
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
|
||||
):
|
||||
"""
|
||||
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
|
||||
collecting tensors per *layer instance* (the concrete indices captured from '*').
|
||||
"""
|
||||
|
||||
prefix = model.base_model_prefix
|
||||
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
|
||||
device_map = device_map or {} # {exact_target_key: device}
|
||||
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
|
||||
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
|
||||
meta_model_state_dict = model.state_dict()
|
||||
missing_keys = set(meta_model_state_dict.keys())
|
||||
|
||||
misc = {}
|
||||
mismatch_keys = set()
|
||||
unexpected_keys = set()
|
||||
# Global thread_pool
|
||||
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
|
||||
|
||||
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
|
||||
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
|
||||
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
|
||||
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
|
||||
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
|
||||
|
||||
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
|
||||
# 1. Create the conversion entries
|
||||
by_conversion_pattern: dict[str, ConversionEntry] = {}
|
||||
for original_key, tensor in state_dict:
|
||||
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
|
||||
if matched_pattern is not None:
|
||||
converter = source_to_target[matched_pattern] # TODO make sure its the ref
|
||||
sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key)
|
||||
entry_key = "|".join(converter.target_keys)
|
||||
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
|
||||
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
|
||||
converter_key = sub_with_extractor(matched_pattern)
|
||||
else:
|
||||
converter = WeightConverter(original_key)
|
||||
converter_key = entry_key = target_key = original_key
|
||||
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
|
||||
|
||||
_dtype = dtype
|
||||
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
|
||||
for t in target_key.split("|"):
|
||||
if t.startswith(prefix) and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", t, count=1)) is not None:
|
||||
t = re.sub(f"^{prefix}.", "", t, count=1)
|
||||
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
|
||||
t = f"{prefix}.{t}"
|
||||
new_target_key.append(t)
|
||||
empty_param = meta_model_state_dict.get(t)
|
||||
# If it does not exist, it's unexpected
|
||||
if empty_param is None:
|
||||
unexpected_keys.add(t)
|
||||
continue
|
||||
|
||||
if quantizer is not None and quantizer.param_needs_quantization(model, t):
|
||||
if quantizer.__class__.__name__ == "FineGrainedFP8HfQuantizer":
|
||||
from .integrations.finegrained_fp8 import Fp8Quantize
|
||||
|
||||
converter.quantization_operation = Fp8Quantize() # TODO support other methods
|
||||
else:
|
||||
raise ValueError("This quantization method is gonna be supported SOOOON")
|
||||
else:
|
||||
_dtype = dtype
|
||||
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
|
||||
if matched_dtype_pattern is not None:
|
||||
_dtype = dtype_plan[matched_dtype_pattern]
|
||||
elif empty_param.dtype != _dtype:
|
||||
_dtype = empty_param.dtype
|
||||
|
||||
first_target_key = new_target_key[0]
|
||||
target_key = "|".join(new_target_key)
|
||||
|
||||
future = None
|
||||
if device_mesh:
|
||||
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
|
||||
empty_param = meta_model_state_dict.get(first_target_key)
|
||||
if getattr(converter, "distributed_operation", {}) is None:
|
||||
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
|
||||
converter.distributed_operation = tp_layer(
|
||||
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
|
||||
)
|
||||
# VERY IMPORTANT: this tells us wether we collected stuffs or not.
|
||||
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
|
||||
future = spawn_tp_materialize(
|
||||
thread_pool,
|
||||
tensor,
|
||||
_dtype,
|
||||
converter.distributed_operation,
|
||||
shard_index,
|
||||
)
|
||||
|
||||
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
|
||||
future = spawn_materialize(thread_pool, tensor, _dtype)
|
||||
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
|
||||
|
||||
# 2. Actually convert the ckpt
|
||||
inverse_converters = {}
|
||||
keys = list(by_conversion_pattern.keys())
|
||||
|
||||
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
|
||||
for key in keys[::-1]: # revert to process simple keys first
|
||||
group = by_conversion_pattern.pop(key)
|
||||
converter = group.weight_converter
|
||||
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
|
||||
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({"Materializing param": layer_name})
|
||||
pbar.refresh()
|
||||
concrete_target_keys = layer_name.split("|")
|
||||
try:
|
||||
if bool(set(concrete_target_keys) - unexpected_keys):
|
||||
with log_to_misc(layer_name, misc):
|
||||
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
|
||||
|
||||
for op in operations:
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
values = op.convert(values, model.config)
|
||||
|
||||
values = [values] if not isinstance(values, list) else values
|
||||
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
|
||||
realized_value = {
|
||||
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
|
||||
}
|
||||
|
||||
for k in list(realized_value.keys()).copy():
|
||||
if op := converter.quantization_operation:
|
||||
with log_to_misc(layer_name, misc, op=op):
|
||||
realized_value.update(
|
||||
op.convert(
|
||||
{k: realized_value.pop(k)}, quant_config=quantizer.quantization_config
|
||||
)
|
||||
)
|
||||
|
||||
for k, output_value in realized_value.items():
|
||||
for src in converter.source_keys: # what should happen to k when we meet k at saving
|
||||
inverse_converters[k] = {src: converter}
|
||||
set_param_for_module(
|
||||
model,
|
||||
k,
|
||||
output_value,
|
||||
mismatch_keys,
|
||||
missing_keys,
|
||||
misc,
|
||||
converter.distributed_operation,
|
||||
)
|
||||
|
||||
except SkipLayer:
|
||||
continue
|
||||
del group
|
||||
|
||||
model.inverse_converters = inverse_converters
|
||||
thread_pool.shutdown(wait=False)
|
||||
return missing_keys, unexpected_keys, mismatch_keys, misc
|
||||
|
||||
|
||||
# TODO this is not done yet!
|
||||
def revert_weight_conversion(model, state_dict):
|
||||
mapping = getattr(model, "_checkpoint_conversion_mapping", {}) # IDK why but setting this will fail all llava.
|
||||
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
|
||||
original_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
for pattern, inverse_converter in reverse_key_mapping:
|
||||
# TODO FIXME you name it
|
||||
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
|
||||
replacement = re.sub(r"\(.*\)", "", replacement)
|
||||
key, n_replace = re.subn(pattern, replacement, key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
break
|
||||
original_state_dict[key] = value
|
||||
state_dict = original_state_dict
|
||||
return state_dict
|
||||
@ -411,7 +411,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
"Generation config file not found, using a generation config created from the model config."
|
||||
)
|
||||
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
|
||||
if hasattr(self, "load_custom_generate"):
|
||||
if hasattr(self, "load_custom_generate") and trust_remote_code:
|
||||
try:
|
||||
custom_generate = self.load_custom_generate(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
|
||||
@ -1635,7 +1635,12 @@ class GenerationMixin(ContinuousMixin):
|
||||
|
||||
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
|
||||
if (
|
||||
value is not None
|
||||
and key not in model_args
|
||||
and key not in TransformersKwargs.__optional_keys__
|
||||
and key != "debug_io"
|
||||
):
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
|
||||
@ -383,10 +383,11 @@ class BayesianDetectorModel(PreTrainedModel):
|
||||
)
|
||||
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Parameter):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
module.weight.normal_(mean=0.0, std=0.02)
|
||||
|
||||
def _compute_posterior(
|
||||
self,
|
||||
|
||||
@ -512,10 +512,8 @@ def accelerate_disk_offload(
|
||||
checkpoint_files,
|
||||
device_map,
|
||||
checkpoint_keys,
|
||||
key_renaming_mapping,
|
||||
sharded_metadata,
|
||||
dtype,
|
||||
reverse_key_renaming_mapping,
|
||||
):
|
||||
disk_only_shard_files = []
|
||||
if disk_offload_folder is not None:
|
||||
@ -534,19 +532,13 @@ def accelerate_disk_offload(
|
||||
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
|
||||
else:
|
||||
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
||||
# Fix the weight map keys according to the key mapping
|
||||
weight_map = {
|
||||
key_renaming_mapping[k]: v
|
||||
for k, v in sharded_metadata["weight_map"].items()
|
||||
if k in key_renaming_mapping
|
||||
}
|
||||
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
|
||||
# Find potential checkpoints containing only offloaded weights
|
||||
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
|
||||
disk_offload_index = {
|
||||
name: {
|
||||
"safetensors_file": file,
|
||||
"weight_name": reverse_key_renaming_mapping[name],
|
||||
"weight_name": name,
|
||||
"dtype": str_dtype,
|
||||
}
|
||||
for name, file in weight_map.items()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
|
||||
from ..utils import (
|
||||
@ -24,7 +23,6 @@ if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -151,52 +149,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
||||
return model
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
|
||||
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
|
||||
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
|
||||
int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model and tie the weights, then
|
||||
# check if it contains tied weights
|
||||
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||
tied_model.tie_weights()
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
tied_keys = sum(tied_params, [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision
|
||||
if not has_tied_params:
|
||||
output_emb = model.get_output_embeddings()
|
||||
if output_emb is not None:
|
||||
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
|
||||
return list_last_module
|
||||
|
||||
# otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
|
||||
list_modules = list(model.named_parameters())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = list(set(tied_keys)) + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
|
||||
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
|
||||
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
|
||||
"""
|
||||
|
||||
@ -13,8 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from ..core_model_loading import ConversionOps
|
||||
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
|
||||
|
||||
|
||||
@ -30,6 +33,18 @@ if is_accelerate_available():
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
try:
|
||||
_FP8_DTYPE = torch.float8_e4m3fn
|
||||
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
|
||||
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
|
||||
_FP8_IS_INT = False
|
||||
except AttributeError:
|
||||
_FP8_DTYPE = torch.int8
|
||||
_FP8_MIN, _FP8_MAX = -127, 127
|
||||
_FP8_IS_INT = True
|
||||
logger.warning_once(
|
||||
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
|
||||
)
|
||||
|
||||
|
||||
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
|
||||
@ -332,6 +347,12 @@ class FP8Linear(nn.Linear):
|
||||
if self.weight.element_size() > 1:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
if isinstance(self.weight, torch.distributed.tensor.DTensor):
|
||||
weight = self.weight._local_tensor.contiguous()
|
||||
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
|
||||
else:
|
||||
weight = self.weight.contiguous()
|
||||
scale_inv = self.weight_scale_inv.contiguous()
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
@ -339,9 +360,9 @@ class FP8Linear(nn.Linear):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
self.weight,
|
||||
weight,
|
||||
scale,
|
||||
self.weight_scale_inv,
|
||||
scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
@ -350,9 +371,124 @@ class FP8Linear(nn.Linear):
|
||||
torch_accelerator_module.synchronize()
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
output = torch.nan_to_num(output, nan=0.0)
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
def _ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
class FP8Expert(nn.Module):
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
def __init__(self, config, block_size, device):
|
||||
super().__init__()
|
||||
|
||||
from ..activations import ACT2FN
|
||||
|
||||
self.block_size = block_size
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.intermediate_dim = config.intermediate_size
|
||||
|
||||
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
|
||||
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
|
||||
|
||||
self.gate_up_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
self.down_proj = nn.Parameter(
|
||||
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
|
||||
)
|
||||
|
||||
# Create inverse scale tiles only when using 1-byte types (fp8)
|
||||
if self.gate_up_proj.element_size() == 1:
|
||||
bo, bi = self.block_size
|
||||
|
||||
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
|
||||
gu_scale_o = _ceil_div(Wg_out, bo)
|
||||
gu_scale_i = _ceil_div(Wg_in, bi)
|
||||
self.gate_up_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
|
||||
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
|
||||
dp_scale_o = _ceil_div(Wd_out, bo)
|
||||
dp_scale_i = _ceil_div(Wd_in, bi)
|
||||
self.down_proj_scales_inv = nn.Parameter(
|
||||
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
|
||||
)
|
||||
else:
|
||||
# Match FP8Linear behavior when not using 1-byte weights
|
||||
self.register_parameter("gate_up_proj_scale_inv", None)
|
||||
self.register_parameter("down_proj_scale_inv", None)
|
||||
|
||||
# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
|
||||
self.register_parameter("gate_up_bias", None)
|
||||
self.register_parameter("down_bias", None)
|
||||
|
||||
# Activation used in the MLP (same as your config / ACT2FN)
|
||||
# Keep a handle here; actual usage happens in forward of your MoE block
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
top_k_index: torch.Tensor,
|
||||
top_k_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
num_experts = top_k_weights.shape[1]
|
||||
with torch.no_grad():
|
||||
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
|
||||
expert_mask = expert_mask.permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
for expert_idx in expert_hit:
|
||||
expert_idx = expert_idx[0]
|
||||
if expert_idx == num_experts:
|
||||
continue
|
||||
_, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current_state = hidden_states.index_select(0, token_idx)
|
||||
gate, up = self.linear(
|
||||
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
|
||||
).chunk(2, dim=-1)
|
||||
current_hidden_states = self.act_fn(gate) * up
|
||||
current_hidden_states = self.linear(
|
||||
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
|
||||
)
|
||||
|
||||
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
|
||||
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
|
||||
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
|
||||
if weight.element_size() > 1:
|
||||
return F.linear(input, weight, None)
|
||||
else:
|
||||
# Context manager used to switch among the available accelerators
|
||||
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
|
||||
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
|
||||
with torch_accelerator_module.device(input.device):
|
||||
qinput, scale = act_quant(input, self.block_size[1])
|
||||
output = w8a8_block_fp8_matmul_triton(
|
||||
qinput,
|
||||
weight,
|
||||
scale,
|
||||
weight_scale_inv,
|
||||
self.block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
|
||||
# preceding operations are ready before proceeding
|
||||
torch_accelerator_module.synchronize()
|
||||
return output.to(dtype=input.dtype)
|
||||
|
||||
|
||||
# TODO: we do need this.... but not recursive...
|
||||
def _replace_with_fp8_linear(
|
||||
model,
|
||||
tp_plan=None,
|
||||
@ -361,40 +497,48 @@ def _replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""Replace Linear layers with FP8Linear."""
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
iterator = list(model.named_parameters()).copy()
|
||||
for name, empty_tensor in iterator:
|
||||
current_key_name = name
|
||||
name = name.rsplit(".", 1)[0] if "." in name else name
|
||||
module = model.get_submodule(name)
|
||||
|
||||
for name, module in model.named_children():
|
||||
current_key_name.append(name)
|
||||
|
||||
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
|
||||
current_key_name_str = ".".join(current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
model._modules[name] = FP8Linear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
|
||||
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
|
||||
with init_empty_weights():
|
||||
if (
|
||||
"gate_up_proj" in current_key_name
|
||||
or "down_proj" in current_key_name
|
||||
and "experts" in current_key_name
|
||||
): # Experts!
|
||||
in_features = empty_tensor.size(-2)
|
||||
out_features = empty_tensor.size(-1)
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Expert(
|
||||
config=model.config,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
device=empty_tensor.device,
|
||||
),
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_fp8_linear(
|
||||
module,
|
||||
tp_plan,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
|
||||
current_key_name.pop(-1)
|
||||
elif isinstance(module, nn.Linear):
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
model.set_submodule(
|
||||
name,
|
||||
FP8Linear(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
activation_scheme=quantization_config.activation_scheme,
|
||||
block_size=quantization_config.weight_block_size,
|
||||
),
|
||||
)
|
||||
has_been_replaced = True
|
||||
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||
|
||||
return model, has_been_replaced
|
||||
|
||||
@ -405,7 +549,7 @@ def replace_with_fp8_linear(
|
||||
quantization_config=None,
|
||||
):
|
||||
"""Helper function to replace model layers with FP8 versions."""
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
modules_to_not_convert += ["lm_head"]
|
||||
|
||||
if quantization_config.modules_to_not_convert is not None:
|
||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||
@ -424,3 +568,133 @@ def replace_with_fp8_linear(
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QuantizationOp(ConversionOps):
|
||||
"""Base class for quantization operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Fp8Quantize(QuantizationOp):
|
||||
"""
|
||||
A quantization operation that creates two tensors, weight and scale out of a weight.
|
||||
"""
|
||||
|
||||
reverse_op: type[ConversionOps]
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Dequantize
|
||||
|
||||
def convert(self, input_dict: torch.Tensor, *, quant_config: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
# Unpack single key/value (value may be wrapped in a list)
|
||||
target_keys, value = tuple(input_dict.items())[0]
|
||||
value = value[0] if isinstance(value, list) else value
|
||||
|
||||
# Resolve block size (support dict-like or attr-like quant_config)
|
||||
block_size = None
|
||||
if quant_config is not None:
|
||||
if isinstance(quant_config, dict):
|
||||
block_size = quant_config.get("weight_block_size")
|
||||
else:
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (value.shape[-2], value.shape[-1])
|
||||
|
||||
block_m, block_n = block_size
|
||||
rows, cols = value.shape[-2], value.shape[-1]
|
||||
|
||||
# Enforce exact tiling like your original
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}"
|
||||
)
|
||||
|
||||
# Leading dims can be empty (2D) or include num_experts/... (3D+)
|
||||
leading_shape = value.shape[:-2]
|
||||
rows_tiles = rows // block_m
|
||||
cols_tiles = cols // block_n
|
||||
|
||||
original_shape = value.shape
|
||||
value_fp32 = value.to(torch.float32)
|
||||
|
||||
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
|
||||
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
|
||||
|
||||
# Per-tile max-abs over the block dims
|
||||
# dims: block_m is at -3, block_n is at -1 after the reshape
|
||||
max_abs = reshaped.abs().amax(dim=(-3, -1))
|
||||
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
|
||||
|
||||
# Tile scale (we store inverse scale like your Linear: weight_scale_inv)
|
||||
scales = _FP8_MAX / safe_max_abs
|
||||
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
|
||||
|
||||
# Broadcast scales back over the block dims and quantize
|
||||
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
|
||||
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
|
||||
scaled = reshaped * scales_broadcast
|
||||
|
||||
if _FP8_IS_INT:
|
||||
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
else:
|
||||
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
|
||||
|
||||
quantized = quantized.reshape(original_shape)
|
||||
|
||||
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
|
||||
if target_keys.endswith("weight"):
|
||||
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
|
||||
else:
|
||||
scale_key = target_keys + "_scales_inv"
|
||||
|
||||
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
|
||||
return {
|
||||
target_keys: quantized,
|
||||
scale_key: inv_scales,
|
||||
}
|
||||
|
||||
|
||||
class Fp8Dequantize(QuantizationOp):
|
||||
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
|
||||
|
||||
def __init__(self, block_size: Optional[tuple[int, int]] = None):
|
||||
self.block_size = block_size
|
||||
self.reverse_op = Fp8Quantize
|
||||
|
||||
def convert(
|
||||
self,
|
||||
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
|
||||
*,
|
||||
context: dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
if isinstance(value, dict):
|
||||
tensors = list(value.values())
|
||||
else:
|
||||
tensors = list(value) if isinstance(value, Sequence) else [value]
|
||||
if len(tensors) != 2:
|
||||
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
|
||||
quantized, scales = tensors
|
||||
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
|
||||
raise TypeError("Fp8Dequantize expects tensors as inputs.")
|
||||
|
||||
quantized_fp32 = quantized.to(torch.float32)
|
||||
rows, cols = quantized_fp32.shape[-2:]
|
||||
block_size = self.block_size
|
||||
if block_size is None:
|
||||
quant_config = context.get("quantization_config")
|
||||
block_size = getattr(quant_config, "weight_block_size", None)
|
||||
if block_size is None:
|
||||
block_size = (rows, cols)
|
||||
block_m, block_n = block_size
|
||||
if rows % block_m != 0 or cols % block_n != 0:
|
||||
raise ValueError(
|
||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
|
||||
)
|
||||
|
||||
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
|
||||
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
|
||||
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
|
||||
dequantized = reshaped * expanded_scales
|
||||
return dequantized.reshape(quantized_fp32.shape)
|
||||
|
||||
@ -236,7 +236,7 @@ class PeftAdapterMixin:
|
||||
**adapter_kwargs,
|
||||
)
|
||||
peft_config.inference_mode = not is_trainable
|
||||
|
||||
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
|
||||
# Create and add fresh new adapters into the model.
|
||||
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ import operator
|
||||
import os
|
||||
import re
|
||||
from functools import partial, reduce
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -140,6 +141,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
def replace_layer_number_by_wildcard(name: str) -> str:
|
||||
"""
|
||||
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
|
||||
a dot (`.`) and the end of the string.
|
||||
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
|
||||
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
|
||||
"""
|
||||
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
|
||||
|
||||
|
||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
|
||||
"""
|
||||
Get the TP style for a parameter from the TP plan.
|
||||
@ -150,11 +161,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
|
||||
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
|
||||
not parent classes for `post_init` calls
|
||||
"""
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
return tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
|
||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
||||
return tp_plan[module_name]
|
||||
return None
|
||||
|
||||
|
||||
@ -306,7 +317,7 @@ def repack_weights(
|
||||
return final_ordered_tensor
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`.
|
||||
@ -358,32 +369,57 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
rank (int): Global rank of the current process/device.
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.dim()
|
||||
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
param_dim = empty_param.ndim
|
||||
# Flatten the mesh to get the total number of devices
|
||||
mesh_shape = device_mesh.shape
|
||||
world_size = reduce(operator.mul, mesh_shape)
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
|
||||
dim = 0
|
||||
|
||||
shard_size = math.ceil(empty_param.size(dim) / world_size)
|
||||
start = rank * shard_size
|
||||
end = min(start + shard_size, empty_param.size(dim))
|
||||
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
shard_size = math.ceil(empty_param.shape[dim] / world_size)
|
||||
start = rank * shard_size
|
||||
# we have the full tensor not 1 part of it.
|
||||
# in that case, we just assume that the weight was properly saved
|
||||
# and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise
|
||||
# to inform that it needs to read form a packed tensor. It will also take care of the module list thingy.
|
||||
# here we take care of potential chunking / layer split / layer chunking.
|
||||
# The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case
|
||||
# actually we still shard dim=0 does not change
|
||||
# so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the
|
||||
# tensor on a certain device (with the input tensor_index)
|
||||
dimensions = param.get_shape()
|
||||
|
||||
# Construct slicing index dynamically
|
||||
end = min(start + shard_size, empty_param.shape[dim])
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
if start < empty_param.shape[dim]:
|
||||
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
|
||||
# special case we don't "shard" just send this entire tensor to the correct rank.
|
||||
if start <= tensor_idx < end:
|
||||
# this tensor does need to be materialized on this device:
|
||||
return param[:]
|
||||
else:
|
||||
return torch.empty([], dtype=torch.int64, device=rank)
|
||||
|
||||
slice_indices = [slice(None)] * len(param.get_shape())
|
||||
|
||||
if start < param.get_shape()[dim]:
|
||||
slice_indices[dim] = slice(start, end)
|
||||
return param[tuple(slice_indices)]
|
||||
dimensions = list(param.shape)
|
||||
param = param[tuple(slice_indices)]
|
||||
if isinstance(param, list): # TODO handle the modulelist case!
|
||||
param = [p[:] for p in param]
|
||||
return param
|
||||
|
||||
dimensions[dim] = 0
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64)
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory....
|
||||
|
||||
|
||||
def distribute_module(
|
||||
@ -410,6 +446,19 @@ class TensorParallelLayer:
|
||||
"""
|
||||
|
||||
use_dtensor = True
|
||||
device_mesh = None
|
||||
rank = None
|
||||
|
||||
# Used to compare the shape of the original tensor
|
||||
empty_param = None
|
||||
|
||||
# Used to init the corresponding DTensor
|
||||
shard = None
|
||||
|
||||
def __init__(self, device_mesh=None, rank=None, empty_param=None):
|
||||
self.rank = rank
|
||||
self.device_mesh = device_mesh
|
||||
self.empty_param = empty_param
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
|
||||
@ -439,12 +488,12 @@ class GatherParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = output_layouts
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -465,6 +514,21 @@ class GatherParallel(TensorParallelLayer):
|
||||
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
shard = [Replicate()]
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
|
||||
distribute_module(
|
||||
module,
|
||||
@ -493,6 +557,23 @@ class IsolatedParallel(TensorParallelLayer):
|
||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||
return outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
mesh = device_mesh or self.device_mesh
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
if mesh is not None:
|
||||
parameter = parameter / mesh.size()
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
@ -515,8 +596,8 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
|
||||
"""
|
||||
|
||||
def __init__(self, *, use_dtensor=True, use_local_output=True):
|
||||
super().__init__()
|
||||
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -537,12 +618,33 @@ class ReplicateParallel(TensorParallelLayer):
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
|
||||
return param
|
||||
parameter, shard = self.shard_tensor(
|
||||
param,
|
||||
param_type=param_type,
|
||||
param_casting_dtype=param_casting_dtype,
|
||||
to_contiguous=to_contiguous,
|
||||
rank=rank,
|
||||
device_mesh=device_mesh,
|
||||
)
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
return parameter
|
||||
|
||||
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
@ -552,13 +654,13 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Replicate(),)
|
||||
self.output_layouts = (output_layouts or Shard(-1),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
@ -578,18 +680,34 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
|
||||
return input_tensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = self.rank
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
# weight would become Shard(1)
|
||||
if param_type == "bias":
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
|
||||
shard = [Shard(-1)]
|
||||
else:
|
||||
shard = [Shard(-2)]
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
|
||||
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
@ -608,6 +726,21 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedColwiseParallel(ColwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -642,18 +775,41 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Placement | None = None,
|
||||
output_layouts: Placement | None = None,
|
||||
use_local_output: bool = True,
|
||||
use_dtensor=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (input_layouts or Shard(-1),)
|
||||
self.output_layouts = (output_layouts or Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
if param_type == "bias":
|
||||
shard = [Replicate()]
|
||||
parameter = param[...]
|
||||
else:
|
||||
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
|
||||
shard = [Shard(-1)]
|
||||
parameter = parameter.to(param_casting_dtype)
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
|
||||
# means Rowwise as nn.Linear is input * weight^T + bias, where
|
||||
@ -725,6 +881,21 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
|
||||
|
||||
class PackedRowwiseParallel(RowwiseParallel):
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
device_mesh = device_mesh or self.device_mesh
|
||||
empty_param = self.empty_param
|
||||
rank = rank if rank is not None else self.rank
|
||||
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
|
||||
# means Colwise as Linear is input * weight^T + bias, where
|
||||
@ -783,8 +954,8 @@ class SequenceParallel(TensorParallelLayer):
|
||||
to ensure that they are replicated.
|
||||
"""
|
||||
|
||||
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
|
||||
super().__init__()
|
||||
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Shard(1),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
@ -793,6 +964,21 @@ class SequenceParallel(TensorParallelLayer):
|
||||
self.sequence_sharding = (Shard(sequence_dim),)
|
||||
self.use_local_output = use_local_output
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
shard = [Replicate()]
|
||||
self.shard = shard
|
||||
return parameter, shard
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
input_tensor = inputs[0]
|
||||
@ -827,10 +1013,34 @@ class GroupedGemmParallel(TensorParallelLayer):
|
||||
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.use_dtensor = False
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
empty_param = self.empty_param
|
||||
ep_rank = self.rank
|
||||
device_mesh = self.device_mesh
|
||||
|
||||
global_num_experts = empty_param.shape[0]
|
||||
if global_num_experts % device_mesh.size() != 0:
|
||||
raise ValueError(
|
||||
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
|
||||
)
|
||||
local_num_experts = global_num_experts // device_mesh.size()
|
||||
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
ep_rank = rank
|
||||
global_num_experts = empty_param.shape[0]
|
||||
@ -851,8 +1061,8 @@ class RouterParallel(TensorParallelLayer):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.use_dtensor = False
|
||||
|
||||
@staticmethod
|
||||
@ -917,6 +1127,20 @@ class RouterParallel(TensorParallelLayer):
|
||||
) # masking class for one hot
|
||||
return router_scores, router_indices
|
||||
|
||||
def shard_tensor(
|
||||
self,
|
||||
param,
|
||||
param_type=None,
|
||||
param_casting_dtype=None,
|
||||
to_contiguous=None,
|
||||
rank=None,
|
||||
device_mesh=None,
|
||||
tensor_idx=None,
|
||||
):
|
||||
parameter = param[...].to(param_casting_dtype)
|
||||
self.shard = None
|
||||
return parameter, None
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
# TODO: i'd like for this to be the default
|
||||
param = param[...].to(param_casting_dtype)
|
||||
@ -1059,6 +1283,9 @@ def shard_and_distribute_module(
|
||||
if current_shard_plan is not None:
|
||||
try:
|
||||
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
|
||||
tp_layer.empty_param = empty_param
|
||||
tp_layer.device_mesh = device_mesh
|
||||
tp_layer.rank = rank
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
@ -1086,7 +1313,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
|
||||
if tp_plan is None:
|
||||
return
|
||||
|
||||
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
|
||||
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
|
||||
unsharded_layers = set(generic_keys)
|
||||
unused_rules = tp_plan
|
||||
|
||||
|
||||
@ -697,6 +697,7 @@ def _validate_yarn_parameters(
|
||||
"original_max_position_embeddings",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
"truncate",
|
||||
}
|
||||
received_keys = set(rope_parameters.keys())
|
||||
rope_type = rope_parameters["rope_type"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -406,13 +406,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if hasattr(module, "logit_scale"):
|
||||
if isinstance(module.logit_scale, nn.Parameter):
|
||||
module.logit_scale.data.fill_(math.log(1 / 0.07))
|
||||
module.logit_scale.fill_(math.log(1 / 0.07))
|
||||
elif isinstance(module, Aimv2AttentionPoolingHead):
|
||||
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -449,13 +449,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if hasattr(module, "logit_scale"):
|
||||
if isinstance(module.logit_scale, nn.Parameter):
|
||||
module.logit_scale.data.fill_(math.log(1 / 0.07))
|
||||
module.logit_scale.fill_(math.log(1 / 0.07))
|
||||
elif isinstance(module, Aimv2AttentionPoolingHead):
|
||||
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -302,21 +302,22 @@ class AlbertPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AlbertAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, AlbertMLMHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -425,7 +426,10 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class AlbertForPreTraining(AlbertPreTrainedModel):
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
|
||||
_tied_weights_keys = {
|
||||
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
|
||||
"predictions.decoder.bias": "predictions.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config: AlbertConfig):
|
||||
super().__init__(config)
|
||||
@ -525,7 +529,6 @@ class AlbertMLMHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@ -537,14 +540,6 @@ class AlbertMLMHead(nn.Module):
|
||||
|
||||
return prediction_scores
|
||||
|
||||
def _tie_weights(self) -> None:
|
||||
# For accelerate compatibility and to not break backward compatibility
|
||||
if self.decoder.bias.device.type == "meta":
|
||||
self.decoder.bias = self.bias
|
||||
else:
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class AlbertSOPHead(nn.Module):
|
||||
def __init__(self, config: AlbertConfig):
|
||||
@ -561,7 +556,10 @@ class AlbertSOPHead(nn.Module):
|
||||
|
||||
@auto_docstring
|
||||
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
|
||||
_tied_weights_keys = {
|
||||
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
|
||||
"predictions.decoder.bias": "predictions.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -823,24 +823,25 @@ class AlignPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, AlignModel):
|
||||
nn.init.xavier_uniform_(module.text_projection.weight)
|
||||
module.text_projection.bias.data.zero_()
|
||||
module.temperature.data.fill_(self.config.temperature_init_value)
|
||||
module.text_projection.bias.zero_()
|
||||
module.temperature.fill_(self.config.temperature_init_value)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
|
||||
@ -770,6 +770,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_module = []
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
@ -797,23 +798,21 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
|
||||
module.text_projection.weight,
|
||||
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
module.text_projection._is_hf_initialized = True
|
||||
nn.init.normal_(
|
||||
module.visual_projection.weight,
|
||||
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
|
||||
)
|
||||
module.visual_projection._is_hf_initialized = True
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
|
||||
|
||||
class AltCLIPVisionTransformer(nn.Module):
|
||||
|
||||
@ -106,7 +106,6 @@ class ApertusConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -429,7 +429,7 @@ class ApertusModel(ApertusPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -123,7 +123,6 @@ class ApertusConfig(LlamaConfig):
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -434,7 +434,7 @@ class ArceeModel(ArceePreTrainedModel):
|
||||
|
||||
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
|
||||
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -99,15 +99,14 @@ class AriaTextConfig(PreTrainedConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `AriaTextModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -585,10 +585,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -608,6 +609,7 @@ class AriaPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaProjector):
|
||||
@ -760,7 +762,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
@ -890,8 +892,6 @@ class AriaModelOutputWithPast(BaseModelOutputWithPast):
|
||||
"""
|
||||
)
|
||||
class AriaModel(AriaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
@ -1048,12 +1048,12 @@ class AriaModel(AriaPreTrainedModel):
|
||||
)
|
||||
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
r"^language_model.model": "model.language_model",
|
||||
r"^vision_tower": "model.vision_tower",
|
||||
r"^multi_modal_projector": "model.multi_modal_projector",
|
||||
r"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -169,6 +169,15 @@ class AriaTextConfig(LlamaConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
base_config_key = "text_config"
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -1187,10 +1196,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
|
||||
"attentions": AriaTextAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, AriaGroupedExpertsGemm):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
@ -1199,6 +1209,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
|
||||
_can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
|
||||
_supports_attention_backend = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
PreTrainedModel._init_weights(self, module)
|
||||
if isinstance(module, AriaProjector):
|
||||
@ -1216,7 +1227,7 @@ class AriaTextModel(LlamaModel):
|
||||
|
||||
|
||||
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config: AriaTextConfig):
|
||||
super().__init__(config)
|
||||
@ -1355,6 +1366,8 @@ class AriaModel(LlavaModel):
|
||||
"""
|
||||
)
|
||||
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
|
||||
@ -300,23 +300,26 @@ class ASTPreTrainedModel(PreTrainedModel):
|
||||
"attentions": ASTSelfAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
||||
# `trunc_normal_cpu` not implemented in `half` issues
|
||||
module.weight.data = nn.init.trunc_normal_(
|
||||
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
||||
).to(module.weight.dtype)
|
||||
module.weight.copy_(
|
||||
nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to(
|
||||
module.weight.dtype
|
||||
)
|
||||
)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, ASTEmbeddings):
|
||||
module.cls_token.data.zero_()
|
||||
module.position_embeddings.data.zero_()
|
||||
module.distillation_token.data.zero_()
|
||||
module.cls_token.zero_()
|
||||
module.position_embeddings.zero_()
|
||||
module.distillation_token.zero_()
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -264,6 +264,7 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of AudioFlamingo3 isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed
|
||||
@ -274,16 +275,16 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@ -435,10 +436,9 @@ class AudioFlamingo3MultiModalProjector(nn.Module):
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -446,9 +446,6 @@ class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, Gene
|
||||
self.audio_tower = AutoModel.from_config(config.audio_config)
|
||||
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
||||
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@ -136,16 +136,12 @@ class AudioFlamingo3MultiModalProjector(VoxtralMultiModalProjector):
|
||||
"""
|
||||
)
|
||||
class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration):
|
||||
_tied_weights_keys = None
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Similar to Qwen2Audio
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
|
||||
def get_audio_features(
|
||||
self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor
|
||||
|
||||
@ -826,21 +826,22 @@ class AutoformerPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "past_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
|
||||
module._init_weight()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
# copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
|
||||
@ -90,7 +90,6 @@ class AyaVisionMultiModalProjector(nn.Module):
|
||||
@auto_docstring
|
||||
class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
config: AyaVisionConfig
|
||||
base_model_prefix = ""
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
@ -163,8 +162,6 @@ class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
|
||||
"""
|
||||
)
|
||||
class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: AyaVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
@ -333,12 +330,12 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
|
||||
)
|
||||
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
r"^language_model.model": "model.language_model",
|
||||
r"^vision_tower": "model.vision_tower",
|
||||
r"^multi_modal_projector": "model.multi_modal_projector",
|
||||
r"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config: AyaVisionConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1126,12 +1126,13 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -1383,7 +1384,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -800,12 +800,13 @@ class BambaPreTrainedModel(PreTrainedModel):
|
||||
# Note: only supports HybridMambaAttentionDynamicCache
|
||||
_is_stateful = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, BambaMixer):
|
||||
module.dt_bias.data.fill_(1.0)
|
||||
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
|
||||
module.D.data.fill_(1.0)
|
||||
module.dt_bias.fill_(1.0)
|
||||
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
|
||||
module.D.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -329,19 +329,21 @@ class BarkPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = False
|
||||
_supports_flash_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear,)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if getattr(module, "bias", None) is not None:
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@ -910,6 +912,9 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
# non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._tied_weights_keys = {}
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight"
|
||||
|
||||
# initialize a modified non causal GPT-like model
|
||||
# note that for there is one embedding layer and one lm_head for each codebook of Encodec
|
||||
@ -1025,25 +1030,6 @@ class BarkFineModel(BarkPreTrainedModel):
|
||||
|
||||
return model_embeds
|
||||
|
||||
def _tie_weights(self):
|
||||
if getattr(self.config, "tie_word_embeddings", True):
|
||||
self._tied_weights_keys = []
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
|
||||
for i in range(self.config.n_codes_total - self.config.n_codes_given):
|
||||
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
|
||||
self._tie_embedding_weights(output_embeddings[i], input_embeddings[i + 1])
|
||||
self._tied_weights_keys.append(f"lm_heads.{i}.weight")
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1580,14 +1566,6 @@ class BarkModel(BarkPreTrainedModel, GenerationMixin):
|
||||
|
||||
return audio
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings list and the output embeddings list.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BarkFineModel",
|
||||
|
||||
@ -164,7 +164,7 @@ class BartConfig(PreTrainedConfig):
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.tie_encoder_decoder = True
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
|
||||
@ -476,19 +476,20 @@ class BartPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -527,7 +528,7 @@ class BartEncoder(BartPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -538,12 +539,9 @@ class BartEncoder(BartPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -674,7 +672,7 @@ class BartDecoder(BartPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -682,12 +680,9 @@ class BartDecoder(BartPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -899,7 +894,10 @@ class BartDecoder(BartPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BartModel(BartPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
@ -908,24 +906,12 @@ class BartModel(BartPreTrainedModel):
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
|
||||
self.encoder = BartEncoder(config, self.shared)
|
||||
self.decoder = BartDecoder(config, self.shared)
|
||||
self.encoder = BartEncoder(config)
|
||||
self.decoder = BartDecoder(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
# Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247
|
||||
if self.shared.weight.device == torch.device(
|
||||
"meta"
|
||||
) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
|
||||
self._tie_embedding_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
|
||||
self._tie_embedding_weights(self.shared, self.decoder.embed_tokens)
|
||||
else:
|
||||
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
|
||||
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
@ -1052,7 +1038,9 @@ class BartModel(BartPreTrainedModel):
|
||||
)
|
||||
class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
@ -1086,11 +1074,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.model._tie_weights()
|
||||
self._tie_embedding_weights(self.lm_head, self.model.shared)
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1240,8 +1223,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
|
||||
"""
|
||||
)
|
||||
class BartForSequenceClassification(BartPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = BartModel(config)
|
||||
@ -1374,8 +1355,6 @@ class BartForSequenceClassification(BartPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BartForQuestionAnswering(BartPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1513,7 +1492,9 @@ class BartDecoderWrapper(BartPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -692,31 +692,32 @@ class BeitPreTrainedModel(PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
|
||||
_supports_sdpa = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BeitEmbeddings):
|
||||
module.cls_token.data.zero_()
|
||||
module.cls_token.zero_()
|
||||
if module.mask_token is not None:
|
||||
module.mask_token.data.zero_()
|
||||
module.mask_token.zero_()
|
||||
if module.position_embeddings is not None:
|
||||
module.position_embeddings.data.zero_()
|
||||
module.position_embeddings.zero_()
|
||||
elif isinstance(module, BeitRelativePositionBias):
|
||||
module.relative_position_bias_table.data.zero_()
|
||||
module.relative_position_bias_table.zero_()
|
||||
elif isinstance(module, BeitLayer):
|
||||
if module.lambda_1 is not None:
|
||||
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_1.fill_(self.config.layer_scale_init_value)
|
||||
module.lambda_2.fill_(self.config.layer_scale_init_value)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -506,16 +506,9 @@ class BertLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -569,21 +562,22 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": BertCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BertLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -770,7 +764,10 @@ class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -864,7 +861,10 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -948,7 +948,10 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
|
||||
|
||||
@auto_docstring
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -456,21 +456,22 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": BertGenerationCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BertGenerationOnlyLMHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@ -629,20 +630,11 @@ class BertGenerationOnlyLMHead(nn.Module):
|
||||
super().__init__()
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = self.decoder(hidden_states)
|
||||
return logits
|
||||
|
||||
def _tie_weights(self):
|
||||
# For accelerate compatibility and to not break backward compatibility
|
||||
if self.decoder.bias.device.type == "meta":
|
||||
self.decoder.bias = self.bias
|
||||
else:
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
@ -650,7 +642,10 @@ class BertGenerationOnlyLMHead(nn.Module):
|
||||
"""
|
||||
)
|
||||
class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1464,16 +1464,9 @@ class BigBirdLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -1521,21 +1514,22 @@ class BigBirdPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BigBirdLMPredictionHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1899,7 +1893,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1999,7 +1996,10 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -2141,7 +2141,10 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1539,19 +1539,20 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -1574,7 +1575,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.attention_type = config.attention_type
|
||||
@ -1592,9 +1593,6 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
@ -1849,7 +1847,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -1861,9 +1859,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
||||
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
@ -2075,7 +2070,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
super().__init__(config)
|
||||
@ -2086,8 +2084,8 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.encoder = BigBirdPegasusEncoder(config, self.shared)
|
||||
self.decoder = BigBirdPegasusDecoder(config, self.shared)
|
||||
self.encoder = BigBirdPegasusEncoder(config)
|
||||
self.decoder = BigBirdPegasusDecoder(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -2100,11 +2098,6 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
|
||||
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
@ -2213,7 +2206,9 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
@ -2247,11 +2242,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||
self.register_buffer("final_logits_bias", new_bias)
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.model._tie_weights()
|
||||
self._tie_embedding_weights(self.lm_head, self.model.shared)
|
||||
|
||||
@auto_docstring
|
||||
# Ignore copy
|
||||
def forward(
|
||||
@ -2374,8 +2364,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
|
||||
"""
|
||||
)
|
||||
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = BigBirdPegasusModel(config)
|
||||
@ -2497,8 +2485,6 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
||||
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2621,8 +2607,6 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
config.is_encoder_decoder = False
|
||||
|
||||
@ -510,7 +510,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["output_projection.weight"]
|
||||
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -332,7 +332,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["output_projection.weight"]
|
||||
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -628,6 +628,7 @@ class BitPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["BitEmbeddings"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
||||
|
||||
@ -433,7 +433,7 @@ class BitNetModel(BitNetPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ class BitNetModel(LlamaModel):
|
||||
|
||||
|
||||
class BitNetForCausalLM(LlamaForCausalLM):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = None
|
||||
_pp_plan = None
|
||||
|
||||
|
||||
@ -438,19 +438,20 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -474,7 +475,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -485,12 +486,9 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -623,7 +621,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -631,12 +629,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -852,7 +847,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
@ -860,8 +858,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
self.encoder = BlenderbotEncoder(config, self.shared)
|
||||
self.decoder = BlenderbotDecoder(config, self.shared)
|
||||
self.encoder = BlenderbotEncoder(config)
|
||||
self.decoder = BlenderbotDecoder(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -1001,7 +999,9 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
@ -1184,7 +1184,9 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
|
||||
class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -431,19 +431,20 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
@ -467,7 +468,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.dropout = config.dropout
|
||||
@ -478,10 +479,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
|
||||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -612,7 +610,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
embed_tokens (nn.Embedding): output embedding
|
||||
"""
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
@ -620,10 +618,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
|
||||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
@ -838,7 +833,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
_tied_weights_keys = {
|
||||
"encoder.embed_tokens.weight": "shared.weight",
|
||||
"decoder.embed_tokens.weight": "shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
@ -846,8 +844,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
|
||||
self.encoder = BlenderbotSmallEncoder(config, self.shared)
|
||||
self.decoder = BlenderbotSmallDecoder(config, self.shared)
|
||||
self.encoder = BlenderbotSmallEncoder(config)
|
||||
self.decoder = BlenderbotSmallDecoder(config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -974,7 +972,9 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
||||
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.shared.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
@ -1144,7 +1144,9 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
|
||||
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.weight": "model.decoder.embed_tokens.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
config.is_decoder = True
|
||||
|
||||
@ -419,13 +419,14 @@ class BlipPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
if isinstance(module, BlipVisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config"):
|
||||
@ -443,10 +444,10 @@ class BlipPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
class BlipEncoder(nn.Module):
|
||||
@ -797,8 +798,11 @@ class BlipModel(BlipPreTrainedModel):
|
||||
)
|
||||
class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
||||
config: BlipConfig
|
||||
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
|
||||
main_input_name = "pixel_values"
|
||||
_tied_weights_keys = {
|
||||
"text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
|
||||
"text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
|
||||
} # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.
|
||||
|
||||
def __init__(self, config: BlipConfig):
|
||||
super().__init__(config)
|
||||
@ -963,7 +967,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
config: BlipConfig
|
||||
_tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
|
||||
"text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config: BlipConfig):
|
||||
super().__init__(config)
|
||||
@ -971,7 +978,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
|
||||
self.vision_model = BlipVisionModel(config.vision_config)
|
||||
|
||||
self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
|
||||
|
||||
self.text_decoder = BlipTextLMHeadModel(config.text_config)
|
||||
|
||||
self.decoder_pad_token_id = config.text_config.pad_token_id
|
||||
|
||||
@ -473,16 +473,9 @@ class BlipTextLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def _tie_weights(self):
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
@ -511,15 +504,16 @@ class BlipTextPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
_no_split_modules = []
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
# Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
|
||||
@ -744,7 +738,10 @@ class BlipTextModel(BlipTextPreTrainedModel):
|
||||
|
||||
# Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
|
||||
class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"cls.predictions.decoder.bias": "cls.predictions.bias",
|
||||
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -409,19 +409,20 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_range
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor)
|
||||
module.weight.normal_(mean=0.0, std=factor)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, Blip2VisionEmbeddings):
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
@ -435,7 +436,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
Blip2ForImageTextRetrieval,
|
||||
),
|
||||
):
|
||||
module.query_tokens.data.zero_()
|
||||
module.query_tokens.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
|
||||
@ -1049,10 +1050,6 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
else:
|
||||
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
||||
|
||||
# Update _tied_weights_keys using the base model used.
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
|
||||
self.language_model = language_model
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@ -1076,11 +1073,6 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _tie_weights(self):
|
||||
if not self.config.use_decoder_only_language_model:
|
||||
self.language_model.encoder.embed_tokens = self.language_model.shared
|
||||
self.language_model.decoder.embed_tokens = self.language_model.shared
|
||||
|
||||
@filter_out_non_signature_kwargs()
|
||||
@auto_docstring
|
||||
def get_text_features(
|
||||
@ -1612,10 +1604,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
else:
|
||||
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
|
||||
|
||||
# Update _tied_weights_keys using the base model used.
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
|
||||
self.language_model = language_model
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@ -1639,11 +1627,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _tie_weights(self):
|
||||
if not self.config.use_decoder_only_language_model:
|
||||
self.language_model.encoder.embed_tokens = self.language_model.shared
|
||||
self.language_model.decoder.embed_tokens = self.language_model.shared
|
||||
|
||||
def _preprocess_accelerate(self):
|
||||
r"""
|
||||
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
||||
|
||||
@ -425,19 +425,20 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -722,7 +723,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"}
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -447,7 +447,6 @@ class BltCrossAttention(nn.Module):
|
||||
@auto_docstring
|
||||
class BltPreTrainedModel(PreTrainedModel):
|
||||
config: BltConfig
|
||||
base_model_prefix = ""
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BltTransformerLayer"]
|
||||
@ -1231,7 +1230,7 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin):
|
||||
config: BltConfig
|
||||
_can_compile_fullgraph = False
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
|
||||
|
||||
def __init__(self, config: BltConfig):
|
||||
super().__init__(config.get_text_config())
|
||||
|
||||
@ -964,7 +964,7 @@ class BltForCausalLM(MllamaForCausalLM):
|
||||
config: BltConfig
|
||||
_can_compile_fullgraph = False
|
||||
base_model_prefix = "model"
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
|
||||
|
||||
def __init__(self, config: BltConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -175,7 +175,6 @@ class BridgeTowerTextConfig(PreTrainedConfig):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
@ -298,7 +297,7 @@ class BridgeTowerConfig(PreTrainedConfig):
|
||||
|
||||
self.text_config = text_config
|
||||
self.vision_config = vision_config
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["BridgeTowerConfig", "BridgeTowerTextConfig", "BridgeTowerVisionConfig"]
|
||||
|
||||
@ -919,6 +919,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
std = self.config.initializer_factor
|
||||
if isinstance(module, BridgeTowerVisionTransformer):
|
||||
@ -927,7 +928,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
fc_std = (2 * self.config.hidden_size) ** -0.5
|
||||
for block in module.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std)
|
||||
block.attn.in_proj_bias.data.zero_()
|
||||
block.attn.in_proj_bias.zero_()
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std)
|
||||
@ -935,15 +936,15 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
|
||||
nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std)
|
||||
nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
|
||||
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.05 * std)
|
||||
module.weight.normal_(mean=0.0, std=0.05 * std)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BridgeTowerForContrastiveLearning):
|
||||
module.logit_scale.data.fill_(self.config.logit_scale_init_value)
|
||||
module.logit_scale.fill_(self.config.logit_scale_init_value)
|
||||
|
||||
if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
|
||||
@ -1497,7 +1498,7 @@ class BridgeTowerITMHead(nn.Module):
|
||||
"""
|
||||
)
|
||||
class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
||||
_tied_weights_keys = ["mlm_score.decoder.weight"]
|
||||
_tied_weights_keys = {"mlm_score.decoder.weight": "bridgetower.text_model.embeddings.word_embeddings.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -514,20 +514,21 @@ class BrosPreTrainedModel(PreTrainedModel):
|
||||
config: BrosConfig
|
||||
base_model_prefix = "bros"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, BrosRelationExtractor):
|
||||
nn.init.normal_(module.dummy_node, std=std)
|
||||
|
||||
|
||||
@ -383,7 +383,6 @@ class CamembertLMHead(nn.Module):
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.dense(features)
|
||||
@ -395,14 +394,6 @@ class CamembertLMHead(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
# For accelerate compatibility and to not break backward compatibility
|
||||
if self.decoder.bias.device.type == "meta":
|
||||
self.decoder.bias = self.bias
|
||||
else:
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class CamembertPreTrainedModel(PreTrainedModel):
|
||||
@ -419,21 +410,22 @@ class CamembertPreTrainedModel(PreTrainedModel):
|
||||
"cross_attentions": CamembertCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, CamembertLMHead):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
class CamembertEmbeddings(nn.Module):
|
||||
@ -745,7 +737,10 @@ class CamembertModel(CamembertPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class CamembertForMaskedLM(CamembertPreTrainedModel):
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1191,7 +1186,10 @@ class CamembertForQuestionAnswering(CamembertPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "camembert.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -53,6 +53,11 @@ class CamembertModel(RobertaModel):
|
||||
|
||||
|
||||
class CamembertForMaskedLM(RobertaForMaskedLM):
|
||||
_tied_weights_keys = {
|
||||
"lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight",
|
||||
"lm_head.decoder.bias": "lm_head.bias",
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.camembert
|
||||
|
||||
@ -688,12 +688,11 @@ class CanineLMPredictionHead(nn.Module):
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
|
||||
hidden_states = self.transform(hidden_states)
|
||||
@ -720,19 +719,20 @@ class CaninePreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "canine"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -1009,7 +1009,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -562,6 +562,7 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
@ -576,7 +577,7 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range)
|
||||
for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]:
|
||||
if embedding.padding_idx is not None:
|
||||
embedding.weight.data[embedding.padding_idx].zero_()
|
||||
embedding.weight[embedding.padding_idx].zero_()
|
||||
elif isinstance(module, ChineseCLIPVisionAttention):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
@ -602,12 +603,12 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
|
||||
|
||||
@ -1308,28 +1308,29 @@ class ClapPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["audio", "text"]
|
||||
supports_gradient_checkpointing = False
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
|
||||
if isinstance(module, ClapTextEmbeddings):
|
||||
module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_type_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embeddings.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_type_embeddings.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, ClapModel):
|
||||
module.logit_scale_a.data.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_t.data.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_a.fill_(math.log(self.config.logit_scale_init_value))
|
||||
module.logit_scale_t.fill_(math.log(self.config.logit_scale_init_value))
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, (nn.Conv2d, nn.Linear)):
|
||||
in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
|
||||
nn.init.normal_(module.weight, std=in_proj_std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, ClapAudioSelfAttention):
|
||||
module.relative_position_bias_table.data.zero_()
|
||||
module.relative_position_bias_table.zero_()
|
||||
|
||||
|
||||
class ClapAudioModel(ClapPreTrainedModel):
|
||||
|
||||
@ -408,12 +408,13 @@ class CLIPPreTrainedModel(PreTrainedModel):
|
||||
"attentions": CLIPAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, CLIPTextEmbeddings):
|
||||
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, CLIPVisionEmbeddings):
|
||||
factor = self.config.initializer_factor
|
||||
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
||||
@ -459,10 +460,10 @@ class CLIPPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module):
|
||||
|
||||
@ -427,12 +427,13 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, CLIPSegTextEmbeddings):
|
||||
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, CLIPSegVisionEmbeddings):
|
||||
factor = self.config.initializer_factor
|
||||
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
||||
@ -463,10 +464,10 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
|
||||
|
||||
@ -781,17 +781,18 @@ class ClvpPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module: nn.Module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=factor * 0.02)
|
||||
module.weight.normal_(mean=0.0, std=factor * 0.02)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, ClvpRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, ClvpEncoderMLP):
|
||||
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
||||
@ -800,22 +801,22 @@ class ClvpPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, ClvpEncoder):
|
||||
config = self.config.get_text_config()
|
||||
factor = config.initializer_factor
|
||||
module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
|
||||
module.projection.weight.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5))
|
||||
elif isinstance(module, ClvpConditioningEncoder):
|
||||
module.mel_conv.weight.data.normal_(mean=0.0, std=factor)
|
||||
module.mel_conv.bias.data.zero_()
|
||||
module.mel_conv.weight.normal_(mean=0.0, std=factor)
|
||||
module.mel_conv.bias.zero_()
|
||||
elif isinstance(module, ClvpForCausalLM):
|
||||
for name, p in module.named_parameters():
|
||||
if name == "c_proj.weight":
|
||||
p.data.normal_(
|
||||
p.normal_(
|
||||
mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers))
|
||||
)
|
||||
elif isinstance(module, ClvpModelForConditionalGeneration):
|
||||
module.logit_scale.data.fill_(self.config.logit_scale_init_value)
|
||||
module.logit_scale.fill_(self.config.logit_scale_init_value)
|
||||
|
||||
if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
class ClvpEncoder(ClvpPreTrainedModel):
|
||||
|
||||
@ -283,19 +283,20 @@ class CodeGenPreTrainedModel(PreTrainedModel):
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear,)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -560,7 +561,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -466,7 +466,7 @@ class CohereModel(CoherePreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -447,7 +447,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -129,7 +129,6 @@ class Cohere2VisionCausalLMOutputWithPast(ModelOutput):
|
||||
@auto_docstring
|
||||
class Cohere2VisionPreTrainedModel(PreTrainedModel):
|
||||
config: Cohere2VisionConfig
|
||||
base_model_prefix = "model"
|
||||
input_modalities = ["image", "text"]
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
@ -143,6 +142,7 @@ class Cohere2VisionPreTrainedModel(PreTrainedModel):
|
||||
"hidden_states": "DecoderLayer",
|
||||
"attentions": "Attention",
|
||||
}
|
||||
base_model_prefix = "model"
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
@ -268,7 +268,7 @@ class Cohere2VisionModel(Cohere2VisionPreTrainedModel):
|
||||
)
|
||||
class Cohere2VisionForConditionalGeneration(Cohere2VisionPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
|
||||
def __init__(self, config: Cohere2VisionConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -144,7 +144,15 @@ def convert_colpali_weights_to_hf(
|
||||
|
||||
# Tie the weights (following ColPali's `__init__`` step)
|
||||
if model.vlm.language_model._tied_weights_keys is not None:
|
||||
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
|
||||
prefix = "vlm.language_model."
|
||||
prefixed_mapping = {
|
||||
f"{prefix}{target}": f"{prefix}{source}"
|
||||
for target, source in model.vlm.language_model._tied_weights_keys.items()
|
||||
}
|
||||
if isinstance(model._tied_weights_keys, dict):
|
||||
model._tied_weights_keys.update(prefixed_mapping)
|
||||
else:
|
||||
model._tied_weights_keys = prefixed_mapping
|
||||
|
||||
# Sanity check: ensure all keys are the same
|
||||
state_dict_keys_old = set(original_state_dict.keys())
|
||||
|
||||
@ -38,6 +38,7 @@ class ColPaliPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
@ -46,13 +47,13 @@ class ColPaliPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -113,7 +114,6 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
self.vocab_size = config.vlm_config.text_config.vocab_size
|
||||
|
||||
self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
||||
self._tied_weights_keys = [f"vlm.language_model.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.embedding_proj_layer = nn.Linear(
|
||||
@ -186,9 +186,6 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.vlm.set_output_embeddings(new_embeddings)
|
||||
|
||||
def tie_weights(self):
|
||||
return self.vlm.tie_weights()
|
||||
|
||||
def resize_token_embeddings(
|
||||
self,
|
||||
new_num_tokens: Optional[int] = None,
|
||||
|
||||
@ -46,6 +46,7 @@ class ColQwen2PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
@ -54,13 +55,13 @@ class ColQwen2PreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -118,7 +119,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
||||
self.config.vlm_config.text_config.hidden_size,
|
||||
self.embedding_dim,
|
||||
)
|
||||
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||
|
||||
self.post_init()
|
||||
|
||||
@ -222,9 +222,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.vlm.set_output_embeddings(new_embeddings)
|
||||
|
||||
def tie_weights(self):
|
||||
return self.vlm.tie_weights()
|
||||
|
||||
def resize_token_embeddings(
|
||||
self,
|
||||
new_num_tokens: Optional[int] = None,
|
||||
|
||||
@ -304,7 +304,6 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
||||
def __init__(self, config: ColQwen2Config):
|
||||
super().__init__(config)
|
||||
del self._tied_weights_keys
|
||||
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
|
||||
@ -237,10 +237,10 @@ def replace_batch_norm(model):
|
||||
new_module = ConditionalDetrFrozenBatchNorm2d(module.num_features)
|
||||
|
||||
if module.weight.device != torch.device("meta"):
|
||||
new_module.weight.data.copy_(module.weight)
|
||||
new_module.bias.data.copy_(module.bias)
|
||||
new_module.running_mean.data.copy_(module.running_mean)
|
||||
new_module.running_var.data.copy_(module.running_var)
|
||||
new_module.weight.copy_(module.weight)
|
||||
new_module.bias.copy_(module.bias)
|
||||
new_module.running_mean.copy_(module.running_mean)
|
||||
new_module.running_var.copy_(module.running_var)
|
||||
|
||||
model._modules[name] = new_module
|
||||
|
||||
@ -970,6 +970,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
|
||||
input_modalities = "image"
|
||||
_no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
std = self.config.init_std
|
||||
xavier_std = self.config.init_xavier_std
|
||||
@ -983,13 +984,13 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
|
||||
nn.init.uniform_(module.row_embeddings.weight)
|
||||
nn.init.uniform_(module.column_embeddings.weight)
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
module.weight.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
|
||||
|
||||
@ -108,24 +108,25 @@ class ConvBertPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "convbert"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, SeparableConv1D):
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, GroupedLinearLayer):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.bias.data.zero_()
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
class SeparableConv1D(nn.Module):
|
||||
@ -707,7 +708,7 @@ class ConvBertGeneratorPredictions(nn.Module):
|
||||
|
||||
@auto_docstring
|
||||
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
|
||||
_tied_weights_keys = ["generator.lm_head.weight"]
|
||||
_tied_weights_keys = {"generator_lm_head.weight": "convbert.embeddings.word_embeddings.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -240,18 +240,19 @@ class ConvNextPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["ConvNextLayer"]
|
||||
_can_record_outputs = {} # hidden states are collected explicitly
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, ConvNextLayer):
|
||||
if module.layer_scale_parameter is not None:
|
||||
module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)
|
||||
module.layer_scale_parameter.fill_(self.config.layer_scale_init_value)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -260,18 +260,19 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
|
||||
input_modalities = "image"
|
||||
_no_split_modules = ["ConvNextV2Layer"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, ConvNextV2GRN):
|
||||
module.weight.data.zero_()
|
||||
module.bias.data.zero_()
|
||||
module.weight.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
|
||||
@auto_docstring
|
||||
|
||||
@ -525,23 +525,24 @@ class CpmAntPreTrainedModel(PreTrainedModel):
|
||||
config: CpmAntConfig
|
||||
base_model_prefix = "cpmant"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
||||
module.weight.normal_(mean=0.0, std=self.config.init_std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
||||
module.weight.normal_(mean=0.0, std=self.config.init_std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, CpmAntLayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, CpmAntSegmentPositionEmbedding):
|
||||
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
|
||||
module.relative_attention_bias.normal_(mean=0.0, std=self.config.init_std)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -698,7 +699,7 @@ class CpmAntModel(CpmAntPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "cpmant.input_embedding.weight"}
|
||||
|
||||
def __init__(self, config: CpmAntConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -409,12 +409,13 @@ class CsmPreTrainedModel(PreTrainedModel):
|
||||
"attentions": CsmAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, CsmCodebooksHead):
|
||||
num_codebooks = module.num_codebooks
|
||||
for i in range(num_codebooks - 1):
|
||||
module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -769,10 +770,9 @@ class CsmBackboneModel(CsmPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
_tied_weights_keys = [
|
||||
"backbone_model.embed_tokens.embed_audio_tokens.weight",
|
||||
"depth_decoder.model.embed_tokens.weight",
|
||||
]
|
||||
_tied_weights_keys = {
|
||||
"backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight"
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -790,13 +790,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
def set_input_embeddings(self, value):
|
||||
self.backbone_model.embed_tokens = value
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_codebooks_embeddings:
|
||||
self._tie_embedding_weights(
|
||||
self.backbone_model.embed_tokens.embed_audio_tokens,
|
||||
self.depth_decoder.model.embed_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get("output_loading_info", False):
|
||||
|
||||
@ -140,12 +140,13 @@ class CsmPreTrainedModel(PreTrainedModel):
|
||||
"attentions": CsmAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, CsmCodebooksHead):
|
||||
num_codebooks = module.num_codebooks
|
||||
for i in range(num_codebooks - 1):
|
||||
module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -420,10 +421,9 @@ class CsmBackboneModel(LlamaModel):
|
||||
"""
|
||||
)
|
||||
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
_tied_weights_keys = [
|
||||
"backbone_model.embed_tokens.embed_audio_tokens.weight",
|
||||
"depth_decoder.model.embed_tokens.weight",
|
||||
]
|
||||
_tied_weights_keys = {
|
||||
"backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight"
|
||||
}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -441,13 +441,6 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
|
||||
def set_input_embeddings(self, value):
|
||||
self.backbone_model.embed_tokens = value
|
||||
|
||||
def _tie_weights(self):
|
||||
if self.config.tie_codebooks_embeddings:
|
||||
self._tie_embedding_weights(
|
||||
self.backbone_model.embed_tokens.embed_audio_tokens,
|
||||
self.depth_decoder.model.embed_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
if kwargs.get("output_loading_info", False):
|
||||
|
||||
@ -188,19 +188,20 @@ class CTRLPreTrainedModel(PreTrainedModel):
|
||||
config: CTRLConfig
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights."""
|
||||
if isinstance(module, (nn.Linear, Conv1D)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
module.weight[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
@ -384,7 +385,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
"""
|
||||
)
|
||||
class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "transformer.w.weight"}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -489,19 +489,20 @@ class CvtPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_no_split_modules = ["CvtLayer"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.copy_(nn.init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range))
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
elif isinstance(module, CvtStage):
|
||||
if self.config.cls_token[module.stage]:
|
||||
module.cls_token.data = nn.init.trunc_normal_(
|
||||
module.cls_token.data, mean=0.0, std=self.config.initializer_range
|
||||
module.cls_token.copy_(
|
||||
nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -437,7 +437,7 @@ class CwmModel(CwmPreTrainedModel):
|
||||
|
||||
@auto_docstring
|
||||
class CwmForCausalLM(CwmPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
|
||||
@ -396,6 +396,7 @@ class DFineConfig(PreTrainedConfig):
|
||||
f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
|
||||
)
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
self.tie_encoder_decoder = True
|
||||
|
||||
|
||||
__all__ = ["DFineConfig"]
|
||||
|
||||
@ -444,6 +444,7 @@ class DFinePreTrainedModel(PreTrainedModel):
|
||||
input_modalities = "image"
|
||||
_no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"]
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
# initialize linear layer bias value according to a given probability value.
|
||||
@ -467,7 +468,7 @@ class DFinePreTrainedModel(PreTrainedModel):
|
||||
module.up.fill_(self.config.up)
|
||||
|
||||
if isinstance(module, DFineMultiscaleDeformableAttention):
|
||||
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
|
||||
nn.init.constant_(module.sampling_offsets.weight, 0.0)
|
||||
default_dtype = torch.get_default_dtype()
|
||||
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
|
||||
2.0 * math.pi / module.n_heads
|
||||
@ -478,10 +479,10 @@ class DFinePreTrainedModel(PreTrainedModel):
|
||||
scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
|
||||
grid_init *= scaling
|
||||
with torch.no_grad():
|
||||
module.sampling_offsets.bias.data[...] = grid_init.flatten()
|
||||
module.sampling_offsets.bias[...] = grid_init.flatten()
|
||||
|
||||
nn.init.constant_(module.attention_weights.weight.data, 0.0)
|
||||
nn.init.constant_(module.attention_weights.bias.data, 0.0)
|
||||
nn.init.constant_(module.attention_weights.weight, 0.0)
|
||||
nn.init.constant_(module.attention_weights.bias, 0.0)
|
||||
|
||||
if isinstance(module, DFineModel):
|
||||
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
||||
@ -490,9 +491,9 @@ class DFinePreTrainedModel(PreTrainedModel):
|
||||
nn.init.constant_(module.enc_score_head.bias, bias)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
if isinstance(module, DFineGate):
|
||||
bias = float(-math.log((1 - 0.5) / 0.5))
|
||||
@ -504,8 +505,8 @@ class DFinePreTrainedModel(PreTrainedModel):
|
||||
init.constant_(module.reg_conf.layers[-1].weight, 0)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
|
||||
nn.init.xavier_uniform_(module.weight_embedding.weight)
|
||||
@ -947,10 +948,10 @@ def replace_batch_norm(model):
|
||||
new_module = DFineFrozenBatchNorm2d(module.num_features)
|
||||
|
||||
if module.weight.device != torch.device("meta"):
|
||||
new_module.weight.data.copy_(module.weight)
|
||||
new_module.bias.data.copy_(module.bias)
|
||||
new_module.running_mean.data.copy_(module.running_mean)
|
||||
new_module.running_var.data.copy_(module.running_var)
|
||||
new_module.weight.copy_(module.weight)
|
||||
new_module.bias.copy_(module.bias)
|
||||
new_module.running_mean.copy_(module.running_mean)
|
||||
new_module.running_var.copy_(module.running_var)
|
||||
|
||||
model._modules[name] = new_module
|
||||
|
||||
@ -1547,9 +1548,14 @@ class DFineObjectDetectionOutput(ModelOutput):
|
||||
)
|
||||
class DFineForObjectDetection(DFinePreTrainedModel):
|
||||
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
||||
_tied_weights_keys = ["bbox_embed", "class_embed"]
|
||||
# We can't initialize the model on meta device as some weights are modified during the initialization
|
||||
_no_split_modules = None
|
||||
_tied_weights_keys = {
|
||||
r"bbox_embed.(?![0])\d+": "bbox_embed.0",
|
||||
r"class_embed.(?![0])\d+": "class_embed.0",
|
||||
"model.decoder.class_embed": "class_embed",
|
||||
"model.decoder.bbox_embed": "bbox_embed",
|
||||
}
|
||||
|
||||
def __init__(self, config: DFineConfig):
|
||||
super().__init__(config)
|
||||
@ -1571,10 +1577,8 @@ class DFineForObjectDetection(DFinePreTrainedModel):
|
||||
]
|
||||
)
|
||||
|
||||
# here self.model.decoder.bbox_embed is null, but not self.bbox_embed
|
||||
self.model.decoder.class_embed = self.class_embed
|
||||
self.model.decoder.bbox_embed = self.bbox_embed
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
||||
@ -415,6 +415,7 @@ class DFineConfig(PreTrainedConfig):
|
||||
f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}"
|
||||
)
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
self.tie_encoder_decoder = True
|
||||
|
||||
|
||||
class DFineMultiscaleDeformableAttention(nn.Module):
|
||||
@ -588,6 +589,7 @@ class DFineDecoderLayer(RTDetrDecoderLayer):
|
||||
|
||||
|
||||
class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
# initialize linear layer bias value according to a given probability value.
|
||||
if isinstance(module, (DFineForObjectDetection, DFineDecoder)):
|
||||
@ -610,7 +612,7 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
||||
module.up.fill_(self.config.up)
|
||||
|
||||
if isinstance(module, DFineMultiscaleDeformableAttention):
|
||||
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
|
||||
nn.init.constant_(module.sampling_offsets.weight, 0.0)
|
||||
default_dtype = torch.get_default_dtype()
|
||||
thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
|
||||
2.0 * math.pi / module.n_heads
|
||||
@ -621,10 +623,10 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
||||
scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1)
|
||||
grid_init *= scaling
|
||||
with torch.no_grad():
|
||||
module.sampling_offsets.bias.data[...] = grid_init.flatten()
|
||||
module.sampling_offsets.bias[...] = grid_init.flatten()
|
||||
|
||||
nn.init.constant_(module.attention_weights.weight.data, 0.0)
|
||||
nn.init.constant_(module.attention_weights.bias.data, 0.0)
|
||||
nn.init.constant_(module.attention_weights.weight, 0.0)
|
||||
nn.init.constant_(module.attention_weights.bias, 0.0)
|
||||
|
||||
if isinstance(module, DFineModel):
|
||||
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
|
||||
@ -633,9 +635,9 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
||||
nn.init.constant_(module.enc_score_head.bias, bias)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
module.bias.zero_()
|
||||
|
||||
if isinstance(module, DFineGate):
|
||||
bias = float(-math.log((1 - 0.5) / 0.5))
|
||||
@ -647,8 +649,8 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel):
|
||||
init.constant_(module.reg_conf.layers[-1].weight, 0)
|
||||
|
||||
if isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
module.weight.fill_(1.0)
|
||||
module.bias.zero_()
|
||||
|
||||
if hasattr(module, "weight_embedding") and self.config.learn_initial_query:
|
||||
nn.init.xavier_uniform_(module.weight_embedding.weight)
|
||||
@ -874,7 +876,17 @@ class DFineModel(RTDetrModel):
|
||||
self.decoder = DFineDecoder(config)
|
||||
|
||||
|
||||
class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
|
||||
class DFineForObjectDetection(RTDetrForObjectDetection):
|
||||
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
||||
# We can't initialize the model on meta device as some weights are modified during the initialization
|
||||
_no_split_modules = None
|
||||
_tied_weights_keys = {
|
||||
r"bbox_embed.(?![0])\d+": "bbox_embed.0",
|
||||
r"class_embed.(?![0])\d+": "class_embed.0",
|
||||
"model.decoder.class_embed": "class_embed",
|
||||
"model.decoder.bbox_embed": "bbox_embed",
|
||||
}
|
||||
|
||||
def __init__(self, config: DFineConfig):
|
||||
DFinePreTrainedModel.__init__(self, config)
|
||||
|
||||
@ -895,10 +907,8 @@ class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
|
||||
]
|
||||
)
|
||||
|
||||
# here self.model.decoder.bbox_embed is null, but not self.bbox_embed
|
||||
self.model.decoder.class_embed = self.class_embed
|
||||
self.model.decoder.bbox_embed = self.bbox_embed
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user