mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)

Compare commits: v4.46.3 ... safe_seria (10 commits)
Author | SHA1 | Date
---|---|---
 | d4ba7b4849 |
 | 1a5c1c0c3d |
 | 34289174a7 |
 | d2207b32e0 |
 | cdaf56e9c3 |
 | 7550456225 |
 | b0c20e8251 |
 | 2c81acf01c |
 | 8bbe608f01 |
 | 2fe03dfceb |
@@ -1718,6 +1718,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             for ignore_key in self._keys_to_ignore_on_save:
                 if ignore_key in state_dict.keys():
                     del state_dict[ignore_key]
+        # Disable to see the damage.
+        if safe_serialization:
+            if self._keys_to_ignore_on_load_missing is not None:
+                from collections import defaultdict
+
+                ptrs = defaultdict(list)
+                for name, tensor in state_dict.items():
+                    ptrs[tensor.data_ptr()].append(name)
+
+                shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+                for _, names in shared_ptrs.items():
+                    for name in names:
+                        for pat in self._keys_to_ignore_on_load_missing:
+                            if re.search(pat, name):
+                                if name in state_dict:
+                                    del state_dict[name]
 
         # Shard the model if it is too big.
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
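The hunk above drops aliased tensors from the state dict before a safetensors save: safetensors rejects entries that share storage, so tied weights are removed here and later rebuilt from the patterns in `_keys_to_ignore_on_load_missing`. Below is a minimal sketch of the same detection idea; `TiedModel` and the `ignore_patterns` list are invented for illustration, not taken from the library.

```python
# Illustrative only: a toy model with a tied projection, plus the same
# data_ptr() grouping used in the hunk above to find aliased tensors.
import re
from collections import defaultdict

import torch.nn as nn


class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.lm_head = nn.Linear(4, 10, bias=False)
        self.lm_head.weight = self.embed.weight  # tying: both names alias one storage


state_dict = TiedModel().state_dict()

# Group parameter names by the pointer of their underlying storage.
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
    ptrs[tensor.data_ptr()].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # one group containing ['embed.weight', 'lm_head.weight']

# Drop aliases that match "safe to rebuild at load time" patterns, keeping
# one canonical copy so the checkpoint can still be loaded.
ignore_patterns = [r"lm_head.weight"]  # hypothetical pattern list
for names in shared_ptrs.values():
    for name in names:
        if any(re.search(pat, name) for pat in ignore_patterns) and name in state_dict:
            del state_dict[name]

print(sorted(state_dict))  # ['embed.weight']
```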
@@ -1219,6 +1219,11 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
 class Blip2Model(Blip2PreTrainedModel):
     config_class = Blip2Config
     main_input_name = "pixel_values"
+    _keys_to_ignore_on_load_missing = [
+        r"language_model.lm_head.weight",
+        r"language_model.encoder.embed_tokens.weight",
+        r"language_model.decoder.embed_tokens.weight",
+    ]
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
@@ -1241,6 +1246,12 @@ class Blip2Model(Blip2PreTrainedModel):
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
+    def _tie_weights(self):
+        if not self.config.use_decoder_only_language_model:
+            self.language_model.encoder.embed_tokens = self.language_model.shared
+            self.language_model.decoder.embed_tokens = self.language_model.shared
+            self.language_model.lm_head = self.language_model.shared
+
     @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
@@ -1553,6 +1564,11 @@ class Blip2Model(Blip2PreTrainedModel):
 class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
     config_class = Blip2Config
     main_input_name = "pixel_values"
+    # _keys_to_ignore_on_load_missing = [
+    #     r"language_model.lm_head.weight",
+    #     r"language_model.encoder.embed_tokens.weight",
+    #     r"language_model.decoder.embed_tokens.weight",
+    # ]
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
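The new `_tie_weights` re-points the language model's embedding and head modules at one shared table, which is why the matching keys can be declared ignorable when missing from a checkpoint: they are recreated after loading. A toy sketch of that mechanism, with invented class and attribute names rather than BLIP-2's real submodules:

```python
# Illustration of weight tying via module reassignment; names are made up.
import torch.nn as nn


class ToySeq2Seq(nn.Module):
    def __init__(self, vocab_size=100, hidden=8):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, hidden)
        self.encoder_embed = nn.Embedding(vocab_size, hidden)
        self.decoder_embed = nn.Embedding(vocab_size, hidden)

    def _tie_weights(self):
        # Re-point both embeddings at the shared table; afterwards the three
        # state_dict entries all alias a single tensor.
        self.encoder_embed = self.shared
        self.decoder_embed = self.shared


model = ToySeq2Seq()
model._tie_weights()
sd = model.state_dict()
keys = ("shared.weight", "encoder_embed.weight", "decoder_embed.weight")
assert len({sd[k].data_ptr() for k in keys}) == 1  # all aliases of one storage
```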
@@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
 
 
 def _get_clones(module, N):
-    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    return nn.ModuleList([module for i in range(N)])
 
 
 def inverse_sigmoid(x, eps=1e-5):
@@ -1778,7 +1778,7 @@ class DetaModel(DetaPreTrainedModel):
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
+    _keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*", "model.decoder"]
 
     # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
     def __init__(self, config: DetaConfig):
@@ -1817,7 +1817,6 @@ class DetaForObjectDetection(DetaPreTrainedModel):
             self.model.decoder.class_embed = self.class_embed
             for box_embed in self.bbox_embed:
                 nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
-
         # Initialize weights and apply final processing
         self.post_init()
 
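The `_get_clones` change above replaces independent deep copies with repeated references to the same module, so clone layers beyond index 0 carry no weights of their own, matching the ignore-on-load patterns. A small illustrative comparison of the two strategies (not taken from the DETA code):

```python
# Two cloning strategies and their effect on parameter sharing (illustrative).
import copy

import torch.nn as nn


def clones_deepcopy(module, n):
    # Independent copies: each entry has its own weights in the state dict.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


def clones_shared(module, n):
    # The same module repeated: entries 1..n-1 alias entry 0's weights, so a
    # checkpoint only needs to store layer 0.
    return nn.ModuleList([module for _ in range(n)])


head = nn.Linear(4, 4)
deep, shared = clones_deepcopy(head, 3), clones_shared(head, 3)
print(deep[0].weight.data_ptr() == deep[1].weight.data_ptr())      # False
print(shared[0].weight.data_ptr() == shared[1].weight.data_ptr())  # True
```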
@@ -630,8 +630,6 @@ class LlamaModel(LlamaPreTrainedModel):
 
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaModel(config)
@@ -357,9 +357,10 @@ class Pix2StructConfig(PretrainedConfig):
         initializer_factor=1.0,
         initializer_range=0.02,
         is_vqa=False,
+        tie_word_embeddings=False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
 
         if text_config is None:
             text_config = {}
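This configuration change forwards `tie_word_embeddings` to the parent constructor instead of letting the base default apply. A stripped-down sketch of why the explicit forwarding matters, using toy classes rather than the real `PretrainedConfig`:

```python
# Toy illustration: a subclass default for a base-class flag only sticks if it
# is forwarded to the parent constructor explicitly.
class BaseConfig:
    def __init__(self, tie_word_embeddings=True, **kwargs):
        self.tie_word_embeddings = tie_word_embeddings


class BrokenSubConfig(BaseConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # the base default (True) silently wins


class FixedSubConfig(BaseConfig):
    def __init__(self, tie_word_embeddings=False, **kwargs):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


print(BrokenSubConfig().tie_word_embeddings)  # True
print(FixedSubConfig().tie_word_embeddings)   # False
```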
@@ -27,6 +27,7 @@ import tempfile
 import unittest
 import unittest.mock as mock
 import warnings
+from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
 
@@ -1626,6 +1627,41 @@ class ModelTesterMixin:
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
 
+    @require_safetensors
+    def test_can_use_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                try:
+                    model_tied.save_pretrained(d, safe_serialization=True)
+                except Exception as e:
+                    raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
+
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+                # Checking the state dicts are correct
+                reloaded_state = model_reloaded.state_dict()
+                for k, v in model_tied.state_dict().items():
+                    self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
+                    torch.testing.assert_close(
+                        v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
+                    )
+
+                # Checking the tensor sharing are correct
+                ptrs = defaultdict(list)
+                for k, v in model_tied.state_dict().items():
+                    ptrs[v.data_ptr()].append(k)
+
+                shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
+
+                for _, shared_names in shared_ptrs.items():
+                    reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
+                    self.assertEqual(
+                        len(reloaded_ptrs),
+                        1,
+                        f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
+                    )
+
     def test_tied_model_weights_key_ignore(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
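The round trip that `test_can_use_safetensors` automates can also be tried by hand. A rough sketch, where the tiny checkpoint name is only an example of a small public model:

```python
# Manual version of the save/reload round trip the new test performs.
import tempfile

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d, safe_serialization=True)  # writes model.safetensors
    reloaded = AutoModel.from_pretrained(d)
    for key, value in model.state_dict().items():
        torch.testing.assert_close(value, reloaded.state_dict()[key])
```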