Compare commits

...

10 Commits

6 changed files with 73 additions and 6 deletions

View File

@ -1718,6 +1718,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for ignore_key in self._keys_to_ignore_on_save:
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
# Disable to see the damage.
if safe_serialization:
if self._keys_to_ignore_on_load_missing is not None:
from collections import defaultdict
ptrs = defaultdict(list)
for name, tensor in state_dict.items():
ptrs[tensor.data_ptr()].append(name)
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
for _, names in shared_ptrs.items():
for name in names:
for pat in self._keys_to_ignore_on_load_missing:
if re.search(pat, name):
if name in state_dict:
del state_dict[name]
# Shard the model if it is too big.
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME

View File

@ -1219,6 +1219,11 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
class Blip2Model(Blip2PreTrainedModel):
config_class = Blip2Config
main_input_name = "pixel_values"
_keys_to_ignore_on_load_missing = [
r"language_model.lm_head.weight",
r"language_model.encoder.embed_tokens.weight",
r"language_model.decoder.embed_tokens.weight",
]
def __init__(self, config: Blip2Config):
super().__init__(config)
@ -1241,6 +1246,12 @@ class Blip2Model(Blip2PreTrainedModel):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
self.language_model.lm_head = self.language_model.shared
@add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
@ -1553,6 +1564,11 @@ class Blip2Model(Blip2PreTrainedModel):
class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
config_class = Blip2Config
main_input_name = "pixel_values"
# _keys_to_ignore_on_load_missing = [
# r"language_model.lm_head.weight",
# r"language_model.encoder.embed_tokens.weight",
# r"language_model.decoder.embed_tokens.weight",
# ]
def __init__(self, config: Blip2Config):
super().__init__(config)

View File

@ -244,7 +244,7 @@ class DetaObjectDetectionOutput(ModelOutput):
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
return nn.ModuleList([module for i in range(N)])
def inverse_sigmoid(x, eps=1e-5):
@ -1778,7 +1778,7 @@ class DetaModel(DetaPreTrainedModel):
)
class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
_keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*", "model.decoder"]
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection.__init__ with DeformableDetr->Deta
def __init__(self, config: DetaConfig):
@ -1817,7 +1817,6 @@ class DetaForObjectDetection(DetaPreTrainedModel):
self.model.decoder.class_embed = self.class_embed
for box_embed in self.bbox_embed:
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
# Initialize weights and apply final processing
self.post_init()

View File

@ -630,8 +630,6 @@ class LlamaModel(LlamaPreTrainedModel):
class LlamaForCausalLM(LlamaPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)

View File

@ -357,9 +357,10 @@ class Pix2StructConfig(PretrainedConfig):
initializer_factor=1.0,
initializer_range=0.02,
is_vqa=False,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
if text_config is None:
text_config = {}

View File

@ -27,6 +27,7 @@ import tempfile
import unittest
import unittest.mock as mock
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@ -1626,6 +1627,41 @@ class ModelTesterMixin:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@require_safetensors
def test_can_use_safetensors(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model_tied = model_class(config)
with tempfile.TemporaryDirectory() as d:
try:
model_tied.save_pretrained(d, safe_serialization=True)
except Exception as e:
raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}")
model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
# Checking the state dicts are correct
reloaded_state = model_reloaded.state_dict()
for k, v in model_tied.state_dict().items():
self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded")
torch.testing.assert_close(
v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}"
)
# Checking the tensor sharing are correct
ptrs = defaultdict(list)
for k, v in model_tied.state_dict().items():
ptrs[v.data_ptr()].append(k)
shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1}
for _, shared_names in shared_ptrs.items():
reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names}
self.assertEqual(
len(reloaded_ptrs),
1,
f"The shared pointers are incorrect, found different pointers for keys {shared_names}",
)
def test_tied_model_weights_key_ignore(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: