Compare commits

...

5 Commits

Author SHA1 Message Date
0d986a1c0a Merge branch 'main' into base-model-loading 2025-03-07 10:56:51 +01:00
55ced4dd9f smol fix 2025-03-06 14:58:04 +01:00
cbf5580f69 fix 2025-03-06 12:21:50 +01:00
5b37639b73 style 2025-03-06 12:19:51 +01:00
85e8f9a5cf fix 2025-03-06 12:17:17 +01:00

View File

@ -819,15 +819,15 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for serialized_param_name, empty_param in state_dict.items():
# we shouldn't rename the key (add/remove prefix) before checking its value in expected keys but it's fine for legacy params
serialized_param_name, _ = model._fix_state_dict_key_on_load(serialized_param_name)
if serialized_param_name not in expected_keys:
continue
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
# we need to use serialized_param_name as file pointer is untouched
param = (
file_pointer.get_slice(serialized_param_name)