[Functional Optim] Test kwargs parity for SGD (#62078)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62078

Ensure that keyword arguments such as momentum and weight decay maintain
parity between optimizer.step and step_param.
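
For illustration, a minimal sketch of the kind of parity check this test performs, assuming the private _FunctionalSGD class from torch.distributed.optim.functional_sgd; the import path, constructor keywords, and step_param signature are assumptions and may differ across PyTorch versions:

import torch
import torch.nn as nn
from torch.optim import SGD
# Assumed import path for the private functional SGD implementation.
from torch.distributed.optim.functional_sgd import _FunctionalSGD

model_step = nn.Linear(4, 4)    # updated through optimizer.step
model_param = nn.Linear(4, 4)   # updated through step_param
model_param.load_state_dict(model_step.state_dict())

kwargs = dict(momentum=0.9, weight_decay=0.01)
optim = SGD(model_step.parameters(), lr=1e-2, **kwargs)
# The functional optimizer holds no parameter list; step_param is fed one
# parameter (and its gradient) at a time.
functional_optim = _FunctionalSGD([], lr=1e-2, allow_empty_param_list=True, **kwargs)

inp = torch.randn(2, 4)
model_step(inp).sum().backward()
model_param(inp).sum().backward()

optim.step()
for p in model_param.parameters():
    functional_optim.step_param(p, p.grad)

# With the same kwargs on both sides, the two models should stay identical.
for p_step, p_param in zip(model_step.parameters(), model_param.parameters()):
    assert torch.allclose(p_step, p_param)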
ghstack-source-id: 134330377

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D29837942

fbshipit-source-id: 1ae39648fc26aebd8aaef1a7ac0e03b598a8ed60
Author: Rohan Varma (committed by Facebook GitHub Bot)
Date: 2021-07-26 22:03:47 -07:00
Parent: 478098aaac
Commit: c0ebeca1a8


@@ -38,7 +38,9 @@ class TestFunctionalOptimParity(TestCase):
         functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
         if not functional_optim_cls:
             raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
-        optim_functional = functional_optim_cls([], *args, allow_empty_param_list=True)
+        optim_functional = functional_optim_cls(
+            [], *args, **kwargs, allow_empty_param_list=True
+        )
         if not hasattr(optim_functional, "step_param"):
             raise ValueError(
                 f"Functional optimizer class {optim_functional} must implement step_param method."
@@ -89,7 +91,7 @@ class TestFunctionalOptimParity(TestCase):
         "Functional optimizer not support on windows, see https://github.com/pytorch/pytorch/issues/62137",
     )
     def test_functional_optim_parity(self):
-        self._test_functional_optim_parity(SGD, 1e-2)
+        self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)
 
 
 if __name__ == "__main__":