mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Functional Optim] Test kwargs parity for SGD (#62078)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62078 Ensure that kwarg arguments such as momentum and weight decay maintain parity between optimizer.step and step_param. ghstack-source-id: 134330377 Test Plan: CI Reviewed By: SciPioneer Differential Revision: D29837942 fbshipit-source-id: 1ae39648fc26aebd8aaef1a7ac0e03b598a8ed60
This commit is contained in:
committed by
Facebook GitHub Bot
parent
478098aaac
commit
c0ebeca1a8
@ -38,7 +38,9 @@ class TestFunctionalOptimParity(TestCase):
|
||||
functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
|
||||
if not functional_optim_cls:
|
||||
raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
|
||||
optim_functional = functional_optim_cls([], *args, allow_empty_param_list=True)
|
||||
optim_functional = functional_optim_cls(
|
||||
[], *args, **kwargs, allow_empty_param_list=True
|
||||
)
|
||||
if not hasattr(optim_functional, "step_param"):
|
||||
raise ValueError(
|
||||
f"Functional optimizer class {optim_functional} must implement step_param method."
|
||||
@ -89,7 +91,7 @@ class TestFunctionalOptimParity(TestCase):
|
||||
"Functional optimizer not support on windows, see https://github.com/pytorch/pytorch/issues/62137",
|
||||
)
|
||||
def test_functional_optim_parity(self):
|
||||
self._test_functional_optim_parity(SGD, 1e-2)
|
||||
self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user