Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Remove old optimizer tests (#120257)
Removes old tests now that all configs are covered in test_compiled_optimizers.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120257
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: b4cef25a1e
Commit: 65519d183b
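For context, the removed file (shown in the diff below) exercised each optimizer by compiling its step() with the eager backend and checking that compiled and eager execution update parameters identically. A minimal standalone sketch of that kind of check, with hypothetical names and not taken from the consolidated test_compiled_optimizers.py suite:

# Sketch only (hypothetical names): compile an optimizer's step() with the
# eager backend and verify it matches a plain eager step.
import torch

def make_model():
    torch.manual_seed(0)  # identical initialization for both copies
    return torch.nn.Linear(10, 10)

model_eager, model_compiled = make_model(), make_model()
inp = torch.ones(10, 10)

opt_eager = torch.optim.Adam(model_eager.parameters(), lr=0.1)
opt_compiled = torch.optim.Adam(model_compiled.parameters(), lr=0.1)

# Populate identical gradients on both copies.
model_eager(inp).sum().backward()
model_compiled(inp).sum().backward()

opt_eager.step()
torch.compile(opt_compiled.step, backend="eager")()

for p_e, p_c in zip(model_eager.parameters(), model_compiled.parameters()):
    torch.testing.assert_close(p_e, p_c)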
@@ -6,7 +6,6 @@ import functools
# Owner(s): ["module: dynamo"]
import inspect

import torch
@@ -15,93 +14,6 @@ import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter

input = torch.ones([10, 10])
model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)])
model(input).sum().backward()


def get_optimizer_step(opt, closure=None):
    # run the patcher so that step has the expected structure
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # unwrap step TWICE to avoid a deliberate graph break due to a limitation of
    # functionalization/no_grad detection--see the [Note on graph break] in optimizer.py
    # This ignores the _use_grad_if_differentiable wrapper, which is fine for now as
    # dynamo does not support differentiable optimizers anyway.
    # This _also_ ignores the outer profiling hook wrapper, which may NOT be fine.
    step_fn = opt.step.__wrapped__.__wrapped__
    if closure is not None:

        def fn():
            step_fn(opt, closure)

    else:

        def fn():
            step_fn(opt)

    return fn


def make_test(optim_cls, closure=None, **kwargs):
    # Remove this conditional when #118230 is fixed
    if optim_cls.__name__ == "Adamax":
        kwargs["foreach"] = True

    opt = optim_cls(model.parameters(), **kwargs)

    def test_fn(self):
        nonlocal opt

        fn = get_optimizer_step(opt, closure=closure)

        with torch.set_grad_enabled(False):
            torch.compile(fn, backend="eager", fullgraph=True)()

    return test_fn


class OptimizerTests(torch._dynamo.test_case.TestCase):
    test_sgd = make_test(torch.optim.SGD, lr=0.01)
    # LBFGS has data-dependent control flow and internally iterates,
    # calling the closure
    # TODO mlazos: re-enable once we have latest pytorch with FakeTensor fix #497
    # test_lbfgs = make_test(
    #     torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum()
    # )

    # RAdam has data-dependent control flow for rectification (needs symint),
    # which breaks the graph; furthermore, the break is inside a for loop, so
    # we bail on the frame entirely. This is basically an xfail; if the frame
    # count goes up, that is an improvement.
    # test_radam = unittest.skipIf(IS_FBCODE, "TypeError: _use_grad() missing")(
    #     make_test(torch.optim.RAdam, exp_graph_count=0)
    # )


# Exclude SparseAdam because other areas of the stack don't support it yet;
# the others are handled specially above.
exclude = {
    "SGD",  # Handled above
    "Optimizer",
    "SparseAdam",  # Unsupported
    "LBFGS",  # Unsupported
    "RAdam",  # Has data-dependent control for rectification (needs symint)
}

optimizers = [
    opt
    for opt in torch.optim.__dict__.values()
    if inspect.isclass(opt)
    and issubclass(opt, torch.optim.Optimizer)
    and opt.__name__ not in exclude
]


for opt in optimizers:
    setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt))


class MyOptimizer(torch.optim.Optimizer):
    def __init__(self, params):
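The opt.step.__wrapped__.__wrapped__ access in get_optimizer_step above relies on each decorator layer around Optimizer.step exposing the callable it wraps via __wrapped__, which is what functools.wraps records. A minimal sketch of that mechanism, using stand-in decorators rather than the real profiling hook or _use_grad_if_differentiable:

import functools

def profiling_hook(fn):  # stand-in for the outer profiling wrapper
    @functools.wraps(fn)  # sets wrapper.__wrapped__ = fn
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

def use_grad(fn):  # stand-in for _use_grad_if_differentiable
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@profiling_hook
@use_grad
def step():
    return "raw step"

# Peeling off two layers reaches the undecorated function, which is what the
# removed get_optimizer_step compiled directly.
assert step.__wrapped__.__wrapped__() == "raw step"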
@@ -180,7 +92,6 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
        tensor = torch.randn(5, 5, dtype=dtype)
        params = Parameter(tensor.detach().clone(), requires_grad=False)
        opt_params = Parameter(tensor.detach().clone(), requires_grad=False)
        print(params, opt_params)

        optim = MyOptimizer([params])
        optim.step()
@@ -188,7 +99,6 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
        opt_optim = MyOptimizer([opt_params])
        opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
        opt_step()
        print(params, opt_params)

        self.assertEqual(params, opt_params)
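Beyond the end-to-end check above, the removed file generated one test per optimizer class at import time by attaching functions to the TestCase with setattr (the loop in the large hunk). A standalone sketch of that pattern, using plain unittest and hypothetical class and helper names:

# Sketch only: dynamically attach one smoke test per torch.optim class.
import inspect
import unittest

import torch


class GeneratedOptimizerTests(unittest.TestCase):
    pass


def _make_smoke_test(optim_cls):
    def test_fn(self):
        param = torch.nn.Parameter(torch.ones(2, 2))
        param.grad = torch.ones_like(param)
        optim_cls([param], lr=0.1).step()  # just check that step() runs

    return test_fn


for cls in torch.optim.__dict__.values():
    if (
        inspect.isclass(cls)
        and issubclass(cls, torch.optim.Optimizer)
        # base class, closure-based, and sparse-grad optimizers need special handling
        and cls not in (torch.optim.Optimizer, torch.optim.LBFGS, torch.optim.SparseAdam)
    ):
        setattr(GeneratedOptimizerTests, f"test_{cls.__name__.lower()}", _make_smoke_test(cls))


if __name__ == "__main__":
    unittest.main()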