Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Parity tests for functional optimizer step_param (#61756)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61756

DDP will support running the optimizer as a communication hook with optimizers that implement a per-parameter/gradient step function, `step_param`. Add parity tests as we implement more optimizers that support `step_param`, to ensure they stay in parity with the regular optimizers.

ghstack-source-id: 134330378

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D29727549

fbshipit-source-id: 18977c896f12b8e478298488b298fd107affcf5f
committed by Facebook GitHub Bot
parent b6d10a3a27
commit 69adb21940
test/test_functional_optim.py  96 lines  (new file)
@@ -0,0 +1,96 @@
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS

if not IS_WINDOWS:
    from torch.distributed.optim.functional_sgd import _FunctionalSGD
    _SUPPORTED_OPTIM_MAPPING = {
        SGD: _FunctionalSGD,
    }


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        return self.lin2(F.relu(self.lin1(t1)))


class TestFunctionalOptimParity(TestCase):
    def _validate_parameters(self, params_1, params_2):
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)

    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        module_optim = MyModule()
        module_functional = MyModule()
        optim_params = module_optim.parameters()
        functional_params = module_functional.parameters()
        optim = optim_cls(optim_params, *args, **kwargs)
        functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        optim_functional = functional_optim_cls([], *args, allow_empty_param_list=True)
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Save old parameters to verify optimizer modifies them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + Backward
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # Optimizer step
            optim.step()
            # Functional optimizer step_param
            for param in module_functional.parameters():
                grad = param.grad
                optim_functional.step_param(param, grad)

        # Validate parameters are equal
        for optim_param, functional_param in zip(
            module_optim.parameters(), module_functional.parameters()
        ):
            self.assertEqual(optim_param, functional_param)
        # Validate parameters are modified.
        for i, (optim_param, functional_param) in enumerate(
            zip(module_optim.parameters(), module_functional.parameters())
        ):
            self.assertNotEqual(old_module_optim_params[i], optim_param)
            self.assertNotEqual(old_module_functional_params[i], functional_param)

    @unittest.skipIf(
        IS_WINDOWS,
        "Functional optimizer not supported on Windows, see https://github.com/pytorch/pytorch/issues/62137",
    )
    def test_functional_optim_parity(self):
        self._test_functional_optim_parity(SGD, 1e-2)


if __name__ == "__main__":
    run_tests()
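As the commit summary notes, the motivation for `step_param` is to let DDP apply the optimizer from a gradient communication hook, updating each parameter as soon as its (averaged) gradient is available instead of calling `optimizer.step()` over the whole model afterwards. The hook wiring itself is not part of this commit; the snippet below is only a minimal standalone sketch of that per-parameter update pattern, reusing the same `_FunctionalSGD` constructor arguments as the test above (the single-process model and local gradients stand in for what DDP would supply, and the import is unavailable on Windows, per the skip above).

import torch
import torch.nn as nn
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Model whose parameters will be updated one at a time.
model = nn.Linear(3, 3, bias=False)

# Same construction as in the test: empty parameter list, lr passed
# positionally, and allow_empty_param_list=True because parameters are
# handed in individually via step_param rather than at construction time.
functional_sgd = _FunctionalSGD([], 1e-2, allow_empty_param_list=True)

loss = model(torch.randn(3, 3)).sum()
loss.backward()

for param in model.parameters():
    # In the DDP communication-hook setting this gradient would be the
    # allreduced (averaged) gradient for this parameter; here it is just
    # the local gradient from the backward pass above.
    functional_sgd.step_param(param, param.grad)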