Add a tensor lr test for optimizers (#123139)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123139
Approved by: https://github.com/albanD
ghstack dependencies: #123134
Jane Xu, 2024-04-02 12:20:56 -07:00
committed by PyTorch MergeBot
parent cb8fc30e4a
commit f2838c99a0
3 changed files with 128 additions and 120 deletions
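
For context, a "tensor lr" is a learning rate passed as a scalar torch.Tensor rather than a Python float. A minimal, illustrative sketch of the pattern (not part of this commit; it mirrors the updated test_adamw below, which pins foreach=False because some foreach configurations reject a tensor lr):

    import torch
    from torch.optim import AdamW

    params = [torch.nn.Parameter(torch.randn(10, 5)), torch.nn.Parameter(torch.randn(10))]
    # The learning rate may be given as a tensor rather than a float.
    opt = AdamW(params, lr=torch.tensor(1e-3), weight_decay=1, amsgrad=True, foreach=False)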


@@ -1,7 +1,6 @@
# Owner(s): ["module: optimizer"]
import functools
import itertools
import torch
from torch.nn import Parameter
@@ -132,110 +131,6 @@ class TestOptim(TestCase):
                sum([rosenbrock(param_t) for param_t in params_t]),
            )

    def _test_basic_cases_template(
        self,
        weight_tensor,
        bias_tensor,
        input_tensor,
        constructor,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=False,
    ):
        maximize_options = {False, constructor_accepts_maximize}
        foreach_options = {False, constructor_accepts_foreach}

        four_arg_constructor = constructor
        if constructor_accepts_maximize and constructor_accepts_foreach:
            pass
        elif constructor_accepts_maximize:

            def four_arg_constructor(weight, bias, maximize, foreach):  # noqa: F811
                self.assertFalse(foreach)
                return constructor(weight, bias, maximize)

        elif constructor_accepts_foreach:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize)
                return constructor(weight, bias, foreach)

        else:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize or foreach)
                return constructor(weight, bias)

        for maximize, foreach in itertools.product(maximize_options, foreach_options):
            with torch.no_grad():
                weight = Parameter(weight_tensor.clone().detach())
                bias = Parameter(bias_tensor.clone().detach())
                input = input_tensor.clone().detach().requires_grad_()
            optimizer = four_arg_constructor(weight, bias, maximize, foreach)

            def fn():
                optimizer.zero_grad()
                y = weight.mv(input)
                if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                    y = y.cuda(bias.get_device())
                loss = (y + bias).pow(2).sum()
                loss.backward()
                return loss

            initial_value = fn().item()
            for _ in range(200):
                optimizer.step(fn)

            if maximize:
                self.assertGreater(fn().item(), initial_value)
            else:
                self.assertLess(fn().item(), initial_value)
    def _test_basic_cases(
        self,
        constructor,
        constructor_accepts_maximize=False,
        constructor_accepts_foreach=False,
    ):
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # Multi-GPU
        if not torch.cuda.device_count() > 1:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )

    def test_sgd_sparse(self):
        for foreach in (False, True):
@@ -250,21 +145,6 @@ class TestOptim(TestCase):
        )

    def test_adamw(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: AdamW(
                [weight, bias],
                lr=torch.tensor(1e-3),
                weight_decay=1,
                amsgrad=True,
                maximize=maximize,
                foreach=False,  # foreach for lr tensors tested in multi configs
            ),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_sparse_adam(self):
        self._test_rosenbrock_sparse(
            lambda params: SparseAdam(params, lr=4e-2), [], True
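
To make the intent of the new test in the next file concrete, here is a hedged, standalone sketch of the property it checks (the names w_float, w_tensor, and the choice of AdamW are illustrative, not taken from this commit): parameters optimized with a plain float lr and an identical copy optimized with the same value wrapped in torch.tensor should track each other step for step.

    import torch

    w_float = torch.nn.Parameter(torch.randn(10, 5))
    w_tensor = torch.nn.Parameter(w_float.detach().clone())
    inp = torch.randn(5)

    opt_float = torch.optim.AdamW([w_float], lr=1e-3, foreach=False)
    opt_tensor = torch.optim.AdamW([w_tensor], lr=torch.tensor(1e-3), foreach=False)

    for _ in range(5):
        for opt, w in ((opt_float, w_float), (opt_tensor, w_tensor)):
            opt.zero_grad()
            w.mv(inp).pow(2).sum().backward()  # simple quadratic loss, as in the tests
            opt.step()

    torch.testing.assert_close(w_float, w_tensor)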


@@ -199,6 +199,56 @@ class TestOptimRenewed(TestCase):
                self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            if "lr" in kwargs:
                del kwargs["lr"]
            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(functools.partial(closure, optimizer_r, weight, bias, inpt))
                    optimizer.step(functools.partial(closure, optimizer, weight_c, bias_c, inpt))
                else:
                    closure(optimizer_r, weight, bias, inpt)
                    closure(optimizer, weight_c, bias_c, inpt)

            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
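
The try/except in test_tensor_lr above covers configurations that reject a tensor lr outright. A small hedged sketch of that rejection path (the optimizer and flags chosen here are illustrative; exactly which combinations raise depends on the optimizer):

    import torch

    p = torch.nn.Parameter(torch.randn(2, 2))
    try:
        # Some configurations, e.g. foreach=True without capturable, refuse a tensor lr.
        torch.optim.Adam([p], lr=torch.tensor(1e-3), foreach=True)
    except ValueError as e:
        print(e)  # message contains "lr as a Tensor is not supported", as asserted above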


@@ -2,6 +2,7 @@
import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
@@ -1047,6 +1048,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -1150,6 +1157,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115607"
@@ -1250,6 +1263,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1327,6 +1346,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
                "TestOptimRenewed",
@@ -1441,6 +1466,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1512,6 +1543,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See discrepancy in https://github.com/pytorch/pytorch/issues/115607"
@@ -1619,6 +1656,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
        ),
    ),
    OptimizerInfo(
@@ -1647,6 +1690,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1733,6 +1782,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -1828,6 +1883,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -1922,6 +1983,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -2031,6 +2098,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -2148,6 +2221,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_tensor_lr",
            ),
            DecorateInfo(
                unittest.skip(
                    "SparseAdam does not support dense gradients, see #116507"