Compare commits

...

2 Commits

Author SHA1 Message Date
f6154fad71 comment 2025-03-28 16:23:29 +01:00
2b4861b831 fix 2025-03-27 08:37:28 +01:00

View File

@ -780,6 +780,9 @@ def _load_state_dict_into_meta_model(
if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
# Used to fix the issue mentioned in #37031: when loading a model with tied weights in state_dict + `tie_word_embeddings = False`,
# we need to make sure they are not loaded as tied weights!
data_ptrs = set()
for param_name, empty_param in state_dict.items():
if param_name not in expected_keys:
continue
@ -849,11 +852,19 @@ def _load_state_dict_into_meta_model(
if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta"
module, param_type = get_module_from_name(model, param_name)
# avoid tied weights
if param.data_ptr() in data_ptrs:
param = param.clone()
module.load_state_dict(
{param_type: param.to(param_device)},
strict=False,
assign=True,
)
# Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights
data_ptrs.add(model.state_dict()[param_name].data_ptr())
else:
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys