Move sparse tests to TestOptimRenewed (#123146)

This is the last of the old TestOptim! With this change, everything is migrated to use OptimizerInfo. Our sparse support is...well, sparse, and the tests do their best to capture which configs actually work. Note that supports_sparse really means "supports sparse grads"; we don't test sparse params (see the sketch below for the distinction).
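A minimal sketch (mine, not from the PR) of that distinction, assuming the illustrative shapes and lr value here are arbitrary:

```python
import torch

# Dense parameter receiving a *sparse* gradient -- the case these tests cover.
param = torch.nn.Parameter(torch.zeros(4))
param.grad = torch.sparse_coo_tensor(
    torch.tensor([[0, 3]]), torch.tensor([1.0, 2.0]), (4,)
)
torch.optim.SparseAdam([param], lr=0.04).step()

# A sparse *parameter* (e.g. torch.zeros(4).to_sparse()) is a separate
# feature and is not exercised by these tests.
```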

1. This PR fixes a bug in Adagrad multi_tensor with maximize: when sparse gradients are present, the correct value of maximize is now passed through instead of False every time. A rough repro sketch follows.
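As I read the bug, the foreach (multi_tensor) path effectively ignored maximize whenever a sparse gradient was involved, so it disagreed with the single-tensor path. This sketch is mine, not the PR's test code, and the tensor values are arbitrary; with the fix in place the two paths should agree:

```python
import torch

# Compare Adagrad's single-tensor and foreach paths on a sparse gradient
# with maximize=True; after the fix they should produce the same update.
def one_step(foreach):
    p = torch.nn.Parameter(torch.tensor([1.5, 1.5], dtype=torch.float64))
    opt = torch.optim.Adagrad([p], lr=0.1, maximize=True, foreach=foreach)
    p.grad = torch.sparse_coo_tensor(
        torch.tensor([[0]]), torch.tensor([2.0], dtype=torch.float64), (2,)
    )
    opt.step()
    return p.detach()

assert torch.allclose(one_step(foreach=False), one_step(foreach=True))
```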

2. This PR also improves coverage. There used to be only 2 configs each; now we have the following configs:

Adagrad:
```
python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Adagrad
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'lr': 0.1}    <--- this and above are CPU
.{'foreach': False, 'lr': 0.1}
{'foreach': True, 'lr': 0.1}
{'maximize': True, 'foreach': False, 'lr': 0.1}
{'maximize': True, 'foreach': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': False, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': True, 'lr': 0.1}
.
----------------------------------------------------------------------
Ran 2 tests in 227.744s

OK
```

SGD:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_SGD
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
{'dampening': 0.5, 'lr': 0.0048}
.{'foreach': False, 'lr': 0.0048}
{'foreach': True, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': False, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': True, 'lr': 0.0048}
.
----------------------------------------------------------------------
Ran 2 tests in 112.801s

OK
```

SparseAdam:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Sparse
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.04}
.{'maximize': True, 'lr': 0.04}
.
----------------------------------------------------------------------
Ran 2 tests in 35.113s

OK
```

Fixes #103322. A side quest in this migration was to re-enable dynamo on the optim tests and track the dynamo issues they trigger, which is complete as of this PR. New tests may add more things to track in dynamo, but there is now an established system for doing so: for every migrated test in TestOptimRenewed, dynamo is either enabled or a bug is tracked.

Next steps:
- Remove the hyperparameter constraints in common_optimizer.py defined by metadata_for_sparse (other than LR, which seems handpicked for the tests to actually pass); a hypothetical sketch of that metadata's shape follows this list. Doing this requires adding more sparse functionality.
- Add more tests!
- Maybe add more optimizers!
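For reference, test_rosenbrock_sparse unpacks metadata_for_sparse as a (kwarg_updates, schedulers_constructors) pair. The entry below only illustrates that shape; the values are made up, and the real per-optimizer entries live in the OptimizerInfo database and may differ:

```python
from torch.optim.lr_scheduler import StepLR

# Hypothetical metadata_for_sparse entry: kwargs forced for the sparse test,
# plus LR scheduler constructors used when with_lrsched=True.
metadata_for_sparse = (
    {"lr": 0.1},
    [lambda opt: StepLR(opt, step_size=500, gamma=0.9)],
)
kwarg_updates, schedulers_constructors = metadata_for_sparse
```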

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123146
Approved by: https://github.com/albanD
ghstack dependencies: #123134, #123139
Author: Jane Xu
Date: 2024-04-02 12:20:58 -07:00
Committed by: PyTorch MergeBot
Commit: d7fe0603a1 (parent f2838c99a0)
4 changed files with 174 additions and 163 deletions


@@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
from torch.optim import Optimizer, SGD
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
-from optim.test_optim import TestOptim, TestDifferentiableOptimizer  # noqa: F401
+from optim.test_optim import TestDifferentiableOptimizer  # noqa: F401
from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401
from torch.nn import Parameter
@@ -34,6 +34,12 @@ def rosenbrock(tensor):
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def drosenbrock(tensor):
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))


@markDynamoStrictTest
class TestOptimRenewed(TestCase):
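Aside (not part of the diff): the hand-written drosenbrock gradient above can be sanity-checked against autograd. The snippet restates the two functions (without the size assert) so it is self-contained:

```python
import torch

def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2

def drosenbrock(tensor):
    x, y = tensor
    return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))

# Autograd on rosenbrock should match the analytic gradient.
p = torch.tensor([1.5, 1.5], dtype=torch.float64, requires_grad=True)
(g_autograd,) = torch.autograd.grad(rosenbrock(p), p)
assert torch.allclose(g_autograd, drosenbrock(p.detach()))
```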
@@ -249,6 +255,129 @@ class TestOptimRenewed(TestCase):
                self.assertEqual(bias, bias_c)

    @parametrize("with_lrsched", [True, False])
    @optims([o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads], dtypes=[torch.float64])
    def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        # Fused impls do not support sparse gradients
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable", "fused"))
        kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse

        if with_lrsched and len(schedulers_constructors) == 0:
            return

        supported_inputs = []
        if len(kwarg_updates) != 0:
            seen = set()
            for i in all_optim_inputs:
                for k in kwarg_updates:
                    if k in i.kwargs:
                        del i.kwargs[k]
                hashable_kwargs = tuple(sorted(i.kwargs.items()))
                if len(i.kwargs) > 0 and hashable_kwargs not in seen:
                    supported_inputs.append(i)
                    seen.add(hashable_kwargs)
                    if "lr" in kwarg_updates:
                        i.kwargs["lr"] = kwarg_updates["lr"]
        else:
            supported_inputs = all_optim_inputs

        for optim_input in supported_inputs:
            kwargs = optim_input.kwargs
            multi_tensor = kwargs.get("foreach", False)

            # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
            if multi_tensor:
                params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=dtype)]
            else:
                params_t = [torch.tensor([1.5, 1.5])]

            params = [Parameter(param_t) for param_t in params_t]
            optimizer = optim_cls(params, **kwargs)
            schedulers = [s(optimizer) for s in (schedulers_constructors if with_lrsched else [])]

            if not optim_info.only_supports_sparse_grads:
                params_c = [Parameter(param_t.clone()) for param_t in params_t]
                optimizer_c = optim_cls(params_c, **kwargs)
                schedulers_c = [s(optimizer_c) for s in (schedulers_constructors if with_lrsched else [])]

            solution = torch.tensor([1, 1])
            with torch.no_grad():
                initial_dist = sum([param.dist(solution) for param in params])

            def get_grad(param, sparse_grad, w):
                grad = drosenbrock(param)
                # NB: We torture test the optimizer by returning an
                # uncoalesced sparse tensor

                # Depending on w, provide only the x or y gradient
                if sparse_grad:
                    if w:
                        i = torch.tensor([[0, 0]], dtype=torch.int64)
                        x = grad[0]
                        v = torch.tensor([x / 4.0, x - x / 4.0])
                    else:
                        i = torch.tensor([[1, 1]], dtype=torch.int64)
                        y = grad[1]
                        v = torch.tensor([y - y / 4.0, y / 4.0])
                    grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
                else:
                    if w:
                        grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                    else:
                        grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
                return grad_out

            def eval(params, sparse_grad, w):
                optimizer.zero_grad()
                if multi_tensor:
                    loss = sum(rosenbrock(param) for param in params)
                else:
                    loss = rosenbrock(params[0])
                loss.backward()

                grads_out = [get_grad(param, sparse_grad, w) for param in params]
                with torch.no_grad():
                    params[0].grad = grads_out[0]
                    if multi_tensor:
                        params[1].grad = grads_out[1].to(dtype=dtype)
                return loss

            for i in range(1800):
                # Do cyclic coordinate descent
                w = i % 2
                optimizer.step(functools.partial(eval, params, True, w))
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params[0]))
                    else:
                        scheduler.step()

                if not optim_info.only_supports_sparse_grads:
                    optimizer_c.step(functools.partial(eval, params_c, False, w))
                    for scheduler in schedulers_c:
                        if isinstance(scheduler, ReduceLROnPlateau):
                            scheduler.step(rosenbrock(params_c[0]))
                        else:
                            scheduler.step()
                    # Tolerance is increased due to floating point error from different
                    # code path for dense case: x v.s. x - x / 4.0 + x / 4.0
                    self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

            if not kwargs.get("maximize", False):
                self.assertLessEqual(
                    sum([param.dist(solution) for param in params]),
                    initial_dist
                )
            else:
                self.assertGreaterEqual(
                    sum([rosenbrock(param) for param in params]),
                    sum([rosenbrock(param_t) for param_t in params_t]),
                )

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
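Aside (not part of the diff): get_grad above deliberately splits each gradient value across duplicate indices so the optimizer sees an uncoalesced COO tensor. A tiny standalone illustration of what that means, with an arbitrary value of 3.0:

```python
import torch

# Two values at the same index -> uncoalesced; coalescing sums them back
# into the full gradient component, mirroring get_grad's "torture test".
x = 3.0
i = torch.tensor([[0, 0]], dtype=torch.int64)
v = torch.tensor([x / 4.0, x - x / 4.0])
g = torch.sparse_coo_tensor(i, v, (2,))
assert not g.is_coalesced()
assert torch.equal(g.coalesce().to_dense(), torch.tensor([3.0, 0.0]))
```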