Compare commits

...

10 Commits

SHA1 Message Date
69af6d10cf fix 2023-12-18 15:26:23 +01:00
2ebe97444a fix 2023-12-18 15:12:00 +01:00
16a4ad8c7f fix 2023-12-18 15:06:24 +01:00
69e984170a add test 2023-12-18 14:17:01 +01:00
b79c54a0a3 fix 2023-12-15 17:10:49 +01:00
67ba4c56cf fix 2023-12-15 16:10:31 +01:00
2f58425101 fix 2023-12-15 16:03:51 +01:00
0668589af0 fix 2023-12-15 15:28:20 +01:00
dec51b6c11 fix 2023-12-15 14:53:12 +01:00
0c60daa79f fix 2023-12-15 14:48:47 +01:00
2 changed files with 75 additions and 17 deletions

View File

@@ -539,15 +539,24 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )


-def set_initialized_submodules(model, state_dict_keys):
+def set_initialized_submodules(model, state_dict_keys, loaded=True):
     """
     Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
     dict.
     """
     for module_name, module in model.named_modules():
-        loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
-        if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-            module._is_hf_initialized = True
+        if loaded:
+            loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
+                module._is_hf_initialized = loaded
+        else:
+            not_loaded_keys = [
+                k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")
+            ]
+            if set(module.state_dict().keys()) == set(not_loaded_keys):
+                module._is_hf_initialized = False


 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
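
For illustration, here is a minimal, self-contained sketch of the helper's two modes after the patch (the toy module and key names are hypothetical, and the function body is condensed from the diff rather than imported from transformers):

from torch import nn


def set_initialized_submodules(model, state_dict_keys, loaded=True):
    # Condensed restatement of the patched helper, kept self-contained for this sketch.
    for module_name, module in model.named_modules():
        keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
        if loaded:
            # Flag submodules whose weights are all covered by the given keys.
            if len(set(module.state_dict().keys()) - set(keys)) == 0:
                module._is_hf_initialized = True
        else:
            # Clear the flag on submodules whose weights are all in the given (mismatched) keys,
            # so a later `model.apply(model._initialize_weights)` re-initializes them.
            if set(module.state_dict().keys()) == set(keys):
                module._is_hf_initialized = False


class ToyModel(nn.Module):  # hypothetical model, not part of transformers
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.classifier = nn.Linear(4, 2)


model = ToyModel()

# Only the encoder weights were present in the checkpoint.
set_initialized_submodules(model, ["encoder.weight", "encoder.bias"])
print(getattr(model.encoder, "_is_hf_initialized", False))     # True
print(getattr(model.classifier, "_is_hf_initialized", False))  # False

# The classifier weights turned out to have mismatched shapes: clear the flag so the
# head gets freshly initialized instead of keeping whatever values it was built with.
set_initialized_submodules(model, ["classifier.weight", "classifier.bias"], loaded=False)
print(getattr(model.classifier, "_is_hf_initialized", "unset"))  # False

The `loaded=False` call is the same move the new `_fast_init` block further down in this diff makes for the keys collected in `mismatched_keys`.
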
@@ -3955,14 +3964,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                             model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
                         )

+        def checkpoint_key_to_model_key(key, remove_prefix_from_model, add_prefix_to_model):
+            model_key = _fix_key(key)
+            if remove_prefix_from_model:
+                # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                model_key = f"{prefix}.{key}"
+            elif add_prefix_to_model:
+                # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                model_key = key[len(prefix) + 1 :]
+            return model_key
+
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
-            if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
-            elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
+            _loaded_keys = [
+                checkpoint_key_to_model_key(k, remove_prefix_from_model, add_prefix_to_model) for k in loaded_keys
+            ]
             set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
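
To make the prefix handling concrete, the sketch below restates the mapping performed by the new `checkpoint_key_to_model_key` helper. In the actual method, `prefix` and `_fix_key` come from the enclosing scope, so `prefix` is passed explicitly here and `_fix_key`'s legacy key renames are left out; the example keys and the "bert" prefix are hypothetical.

def checkpoint_key_to_model_key(key, prefix, remove_prefix_from_model, add_prefix_to_model):
    # Simplified: maps a key from the checkpoint's state dict to the model's naming.
    if remove_prefix_from_model:
        # The model keys carry the base-model prefix but the checkpoint keys don't: add it.
        return f"{prefix}.{key}"
    if add_prefix_to_model:
        # The checkpoint keys carry the prefix but the model keys don't: strip it.
        return key[len(prefix) + 1 :]
    return key


# Loading a bare-encoder checkpoint into a model that nests it under the prefix:
print(checkpoint_key_to_model_key("encoder.layer.0.output.dense.weight", "bert", True, False))
# bert.encoder.layer.0.output.dense.weight

# Loading a full-model checkpoint into the bare encoder:
print(checkpoint_key_to_model_key("bert.encoder.layer.0.output.dense.weight", "bert", False, True))
# encoder.layer.0.output.dense.weight
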
@@ -4004,13 +4021,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                     # If the checkpoint is sharded, we may not have the key here.
                     if checkpoint_key not in state_dict:
                         continue
-                    model_key = checkpoint_key
-                    if remove_prefix_from_model:
-                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
-                        model_key = f"{prefix}.{checkpoint_key}"
-                    elif add_prefix_to_model:
-                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
-                        model_key = ".".join(checkpoint_key.split(".")[1:])
+                    model_key = checkpoint_key_to_model_key(
+                        checkpoint_key, remove_prefix_from_model, add_prefix_to_model
+                    )

                     if (
                         model_key in model_state_dict
@@ -4157,6 +4170,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
                 shutil.rmtree(state_dict_folder)

+        if _fast_init:
+            mismatched_model_keys = [
+                checkpoint_key_to_model_key(x[0], remove_prefix_from_model, add_prefix_to_model)
+                for x in mismatched_keys
+            ]
+            set_initialized_submodules(model, mismatched_model_keys, loaded=False)
+            # This will only initialize submodules that are re-marked as `not loaded` above due to mismatched
+            model.apply(model._initialize_weights)
+
         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
             if "size mismatch" in error_msg:

View File

@@ -2889,6 +2889,42 @@ class ModelTesterMixin:
             else:
                 new_model_without_prefix(input_ids)

+    def test_mismatched_shapes_have_properly_initialized_weights(self):
+        if not self.test_mismatched_shapes:
+            return
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        configs_no_init = _config_zero_init(config)
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+                continue
+
+            with self.subTest(msg=f"Testing {model_class}"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    model = model_class(configs_no_init)
+                    model.save_pretrained(tmp_dir)
+
+                    # Fails when we don't set ignore_mismatched_sizes=True
+                    with self.assertRaises(RuntimeError):
+                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+
+                    logger = logging.get_logger("transformers.modeling_utils")
+
+                    with CaptureLogger(logger) as cl:
+                        new_model = AutoModelForSequenceClassification.from_pretrained(
+                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    for name, param in new_model.named_parameters():
+                        if param.requires_grad:
+                            self.assertIn(
+                                ((param.data.mean() * 1e9).round() / 1e9).item(),
+                                [0.0, 1.0],
+                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                            )
+
     def test_model_is_small(self):
         # Just a consistency check to make sure we are not running tests on 80M parameter models.
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
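
The new test above relies on `_config_zero_init`, which shrinks the config's initializer scales to nearly zero; under such a config, a parameter that went through the model's own weight initialization has a mean that rounds (at nine decimals) to 0.0, or is exactly 1.0 for ones-initialized weights, whereas a head that kept its values after a shape mismatch would not. A toy illustration of that invariant, with made-up tensors:

import torch


def rounded_mean(t):
    # The same rounding the test applies before checking membership in [0.0, 1.0].
    return ((t.mean() * 1e9).round() / 1e9).item()


reinitialized_head = torch.zeros(42, 768)  # what a properly re-initialized head looks like under a zero-init config
print(rounded_mean(reinitialized_head) in [0.0, 1.0])  # True

leftover_values = torch.empty(42, 768).uniform_(-0.5, 0.5)  # a head that was never re-initialized
print(rounded_mean(leftover_values) in [0.0, 1.0])  # almost surely False
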