mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix loading older state_dict into AdamW after refactor (#144972)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144972 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8abdaa286
commit
3908be676c
@ -671,9 +671,28 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
loaded_dict = optim.state_dict()
|
||||
|
||||
# Test that Adam respects the decoupled_weight_decay key
|
||||
new_optim = torch.optim.Adam(model.parameters())
|
||||
new_optim.load_state_dict(loaded_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
# Test that decoupled_weight_decay is always True for AdamW
|
||||
adam_optim = torch.optim.Adam(model.parameters())
|
||||
adam_state_dict = adam_optim.state_dict()
|
||||
self.assertFalse(adam_state_dict["param_groups"][0]["decoupled_weight_decay"])
|
||||
|
||||
new_optim = torch.optim.AdamW(model.parameters())
|
||||
new_optim.load_state_dict(adam_state_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
# Test that state_dicts from the old AdamW (with no decoupled_weight_decay key)
|
||||
# will have decoupled_weight_decay=True in new AdamW:
|
||||
old_adamw_dict = deepcopy(loaded_dict)
|
||||
del old_adamw_dict["param_groups"][0]["decoupled_weight_decay"]
|
||||
self.assertFalse("decoupled_weight_decay" in old_adamw_dict["param_groups"][0])
|
||||
|
||||
new_optim = torch.optim.AdamW(model.parameters())
|
||||
new_optim.load_state_dict(old_adamw_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
def _compare_between(
|
||||
|
@ -49,6 +49,14 @@ class AdamW(Adam):
|
||||
decoupled_weight_decay=True,
|
||||
)
|
||||
|
||||
# Keep decoupled_weight_decay=True on AdamW for backwards compatibility. The override below
# guarantees that the flag is always True after any state is loaded into AdamW, including
# state saved by Adam (flag False) or by an older AdamW that had no such key.
|
||||
def __setstate__(self, state):
    """Restore optimizer state and re-assert decoupled weight decay.

    The parent ``__setstate__`` is run first, then every entry in
    ``self.param_groups`` has its ``"decoupled_weight_decay"`` key set to
    ``True``, regardless of whether the restored group carried the key or
    what value it held.

    Args:
        state: the state object forwarded unchanged to the parent class.
    """
    super().__setstate__(state)
    # Overwrite unconditionally: a missing key and a False value are both
    # replaced, so AdamW never runs with coupled weight decay.
    for param_group in self.param_groups:
        param_group["decoupled_weight_decay"] = True
|
||||
|
||||
|
||||
AdamW.__doc__ = (
|
||||
r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.
|
||||
|
Reference in New Issue
Block a user