mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix loading older state_dict into AdamW after refactor (#144972)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144972
Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
b8abdaa286
commit
3908be676c
@ -671,9 +671,28 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
loaded_dict = optim.state_dict()
|
||||
|
||||
# Test that Adam respects the decoupled_weight_decay key
|
||||
new_optim = torch.optim.Adam(model.parameters())
|
||||
new_optim.load_state_dict(loaded_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
# Test that decoupled_weight_decay is always True for AdamW, even after loading an Adam state_dict in which it is False
|
||||
adam_optim = torch.optim.Adam(model.parameters())
|
||||
adam_state_dict = adam_optim.state_dict()
|
||||
self.assertFalse(adam_state_dict["param_groups"][0]["decoupled_weight_decay"])
|
||||
|
||||
new_optim = torch.optim.AdamW(model.parameters())
|
||||
new_optim.load_state_dict(adam_state_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
# Test that state_dicts from the old AdamW (with no decoupled_weight_decay key)
|
||||
# will have decoupled_weight_decay=True in new AdamW:
|
||||
old_adamw_dict = deepcopy(loaded_dict)
|
||||
del old_adamw_dict["param_groups"][0]["decoupled_weight_decay"]
|
||||
self.assertFalse("decoupled_weight_decay" in old_adamw_dict["param_groups"][0])
|
||||
|
||||
new_optim = torch.optim.AdamW(model.parameters())
|
||||
new_optim.load_state_dict(old_adamw_dict)
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
def _compare_between(
|
||||
|
Reference in New Issue
Block a user