Fix loading older state_dict into AdamW after refactor (#144972)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144972
Approved by: https://github.com/albanD
Jane Xu
2025-01-16 16:44:00 +00:00
committed by PyTorch MergeBot
parent b8abdaa286
commit 3908be676c
2 changed files with 27 additions and 0 deletions
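
For context, a minimal sketch of the scenario this commit fixes. It assumes the post-refactor AdamW that subclasses Adam; the pre-refactor standalone AdamW stored no "decoupled_weight_decay" entry in its param groups, so a checkpoint from it lacks that key:

import torch
from copy import deepcopy

model = torch.nn.Linear(2, 2)
optim = torch.optim.AdamW(model.parameters())

# Simulate a checkpoint written by the pre-refactor AdamW, whose param
# groups had no "decoupled_weight_decay" entry.
old_style_dict = deepcopy(optim.state_dict())
for group in old_style_dict["param_groups"]:
    group.pop("decoupled_weight_decay", None)

new_optim = torch.optim.AdamW(model.parameters())
new_optim.load_state_dict(old_style_dict)
# Without this fix, the missing key could fall back to Adam's default
# (False), silently switching AdamW to coupled weight decay; with the
# fix it is always True for AdamW.
assert new_optim.param_groups[0]["decoupled_weight_decay"]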


@@ -671,9 +671,28 @@ class TestOptimRenewed(TestCase):
        loaded_dict = optim.state_dict()

        # Test that Adam respects the decoupled_weight_decay key
        new_optim = torch.optim.Adam(model.parameters())
        new_optim.load_state_dict(loaded_dict)
        self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])

        # Test that decoupled_weight_decay is always True for AdamW
        adam_optim = torch.optim.Adam(model.parameters())
        adam_state_dict = adam_optim.state_dict()
        self.assertFalse(adam_state_dict["param_groups"][0]["decoupled_weight_decay"])
        new_optim = torch.optim.AdamW(model.parameters())
        new_optim.load_state_dict(adam_state_dict)
        self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])

        # Test that state_dicts from the old AdamW (with no decoupled_weight_decay key)
        # will have decoupled_weight_decay=True in new AdamW:
        old_adamw_dict = deepcopy(loaded_dict)
        del old_adamw_dict["param_groups"][0]["decoupled_weight_decay"]
        self.assertFalse("decoupled_weight_decay" in old_adamw_dict["param_groups"][0])
        new_optim = torch.optim.AdamW(model.parameters())
        new_optim.load_state_dict(old_adamw_dict)
        self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])

    def _compare_between(


@@ -49,6 +49,14 @@ class AdamW(Adam):
            decoupled_weight_decay=True,
        )

    # Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following
    # guarantees that decoupled_weight_decay will always be True for loading any state into
    # AdamW
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group["decoupled_weight_decay"] = True


AdamW.__doc__ = (
    r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.