Compare commits

...

2 Commits

SHA1        Message                     Date
b345f4e7f7  warning was in info mode.   2024-03-26 20:15:07 +09:00
dbf00d5895  grumble                     2024-03-25 22:04:26 +09:00


@@ -595,12 +595,13 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         state_dict._metadata = metadata
 
     error_msgs = []
+    unexpected_keys, missing_keys = [], []
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
     def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
@@ -630,7 +631,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # it's safe to delete it.
     del state_dict
 
-    return error_msgs
+    return error_msgs, unexpected_keys, missing_keys
 
 
 def find_submodule_and_param_name(model, long_key, start_prefix):
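
The point of the new `args` tuple: PyTorch's `nn.Module._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)` appends to the key lists it is handed, so the old code's fresh `[]` literals threw that information away on every call. Sharing one pair of lists across the recursion is what lets the function return them. A minimal sketch of the mechanism, outside transformers (the `Tiny` module, tensor shapes, and stray key are illustrative, not from the patch):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

model = Tiny()
# Checkpoint missing `linear.bias` and carrying an extra `stray.bias`.
state_dict = {"linear.weight": torch.zeros(2, 2), "stray.bias": torch.zeros(2)}
missing_keys, unexpected_keys, error_msgs = [], [], []

def load(module, prefix=""):
    # Same call shape as the patched `args`: the shared lists accumulate
    # across the whole module tree instead of being discarded per call.
    module._load_from_state_dict(
        state_dict, prefix, {}, True, missing_keys, unexpected_keys, error_msgs
    )
    for name, child in module._modules.items():
        if child is not None:
            load(child, prefix + name + ".")

load(model)
print(missing_keys)     # ['linear.bias']
print(unexpected_keys)  # ['stray.bias']
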
@@ -3901,7 +3902,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 remove_prefix_from_model,
                 ignore_mismatched_sizes,
             )
-            error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            error_msgs, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
             offload_index = None
         else:
             # Sharded checkpoint or whole but low_cpu_mem_usage==True
@@ -3974,8 +3975,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                             )
                             error_msgs += new_error_msgs
                     else:
-                        error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
-
+                        mmsg, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                        error_msgs += mmsg
                     # force memory release
                     del state_dict
                     gc.collect()
@@ -4009,9 +4010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
         if len(unexpected_keys) > 0:
-            archs = [] if model.config.architectures is None else model.config.architectures
-            warner = logger.warning if model.__class__.__name__ in archs else logger.info
-            warner(
+            logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"