Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-29 11:14:56 +08:00
Only make a shallow copy when loading optimizer state_dict (#106082)
The one thing we still deep copy is the param_groups, which is much lighter weight. This should also save memory when loading from a checkpoint. The deepcopy was introduced in ecfcf39f30, but module.py had only a shallow copy at that point, so it did not actually bring parity. Incorporates an XLA fix, which is why I'm updating the pin to ca5eab87a7.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106082
Approved by: https://github.com/albanD, https://github.com/Skylion007
committed by PyTorch MergeBot
parent ceea08a986
commit 59d0dea90f
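To make the new copy semantics concrete, here is a minimal sketch using a toy dict shaped like an optimizer state_dict (the field names mirror the real format; the values are illustrative, not taken from the commit):

import copy

# Toy stand-in for an optimizer state_dict: 'state' holds the heavy
# per-parameter buffers, 'param_groups' holds lightweight hyperparameters.
state_dict = {
    "state": {0: {"momentum_buffer": [0.1, 0.2, 0.3]}},  # imagine large tensors
    "param_groups": [{"lr": 0.01, "params": [0]}],
}

shallow = state_dict.copy()                               # new behavior
saved_groups = copy.deepcopy(state_dict["param_groups"])  # only groups deep-copied

# The heavy 'state' mapping is shared rather than duplicated...
assert shallow["state"] is state_dict["state"]
# ...while the deep-copied param_groups can be mutated safely.
assert saved_groups[0] is not state_dict["param_groups"][0]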
@@ -712,8 +712,8 @@ class Optimizer:
             state_dict (dict): optimizer state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """
-        # deepcopy, to be consistent with module API
-        state_dict = deepcopy(state_dict)
+        # shallow copy, to be consistent with module API
+        state_dict = state_dict.copy()
 
         for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
             hook_result = pre_hook(self, state_dict)
@@ -722,7 +722,9 @@ class Optimizer:
 
         # Validate the state_dict
         groups = self.param_groups
-        saved_groups = state_dict['param_groups']
+
+        # Deepcopy as we write into saved_groups later to update state
+        saved_groups = deepcopy(state_dict['param_groups'])
 
         if len(groups) != len(saved_groups):
             raise ValueError("loaded state dict has a different number of "
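For context, a usage sketch of the code path this change affects (the model and optimizer choices are arbitrary examples, not part of the commit):

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one step so the optimizer actually has momentum buffers in its state.
model(torch.randn(8, 4)).sum().backward()
opt.step()

checkpoint = opt.state_dict()

# With this change, load_state_dict only shallow-copies `checkpoint` and
# deep-copies just `param_groups`, so the per-parameter state tensors are
# not duplicated while loading.
new_opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
new_opt.load_state_dict(checkpoint)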