pytorch/test/optim/test_optim.py
Jane Xu c6be5d55a5 Migrate param_group testing to OptimizerInfo (#117675)
Today, our param_group testing effectively pits weight and bias against each other with different optimizer hyperparameters and then checks that the overall result moves in the right direction based on maximize.

This PR introduces two tests to cover param_group behavior:
1. For every optimizer input (excluding differentiable), always force bias to have 0 weight_decay, and then check that the parameters move in the expected direction. This is essentially a replica of today's tests, but it is more methodical because the test mirrors a real use case.
2. To ensure that the different groups actually behave differently, I added another test where lr is effectively 0 in the default group, and verify that the param in the default group does not move while the loss still does (see the sketch below).

Together, these tests do a better job of testing param groups than today's tests, **though we do lose some flavors**. For example, RMSProp also pits centered=True vs. False across the param_groups, Adadelta varies rho, and ASGD varies t0. I don't think this is a real loss, as the previous tests only checked direction, whereas the new tests verify stronger guarantees.

The leftover param group configs are used in conjunction with LRSchedulers.
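To make the second check concrete, here is a minimal hand-written sketch (illustrative only; the real test added here is generated from OptimizerInfo inputs, and the names `weight`, `bias`, and `inp` are invented for the example):

```python
import torch
from torch.optim import SGD

torch.manual_seed(0)
weight = torch.nn.Parameter(torch.randn(10, 5))  # lives in the default group
bias = torch.nn.Parameter(torch.randn(10))       # lives in a group with a real lr
inp = torch.randn(5)

# Default lr is effectively 0, so `weight` should not move; `bias` should still learn.
optim = SGD([{"params": [weight]}, {"params": [bias], "lr": 1e-2}], lr=1e-8)

weight_before = weight.clone().detach()
initial_loss = None
for _ in range(20):
    optim.zero_grad()
    loss = (weight.mv(inp) + bias).pow(2).sum()
    if initial_loss is None:
        initial_loss = loss.item()
    loss.backward()
    optim.step()

assert torch.allclose(weight, weight_before, atol=1e-5)  # default-group param stayed put
assert loss.item() < initial_loss                        # but the loss still went down
```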

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117675
Approved by: https://github.com/albanD
2024-01-22 23:48:46 +00:00


# Owner(s): ["module: optimizer"]
import math
import unittest
import functools
import itertools
from copy import deepcopy
import torch
from torch.nn import Parameter
from torch.optim import (
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, LBFGS, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam, Optimizer
)
from torch.optim.lr_scheduler import (
StepLR,
ConstantLR,
LinearLR,
ExponentialLR,
ReduceLROnPlateau,
PolynomialLR,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
gradcheck,
skipIfRocm,
skipIfTorchDynamo
)
from torch.testing._internal.common_cuda import TEST_CUDA
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from unittest.mock import patch
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
def rosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
def drosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
@skipIfTorchDynamo("This is a TEMPORARY stopgap, see https://github.com/pytorch/pytorch/issues/103322")
class TestOptim(TestCase):
exact_dtype = True
def _test_rosenbrock_sparse(
self,
constructor,
scheduler_constructors=None,
sparse_only=False,
maximize=False,
multi_tensor=False
):
if scheduler_constructors is None:
scheduler_constructors = []
# For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
if multi_tensor:
params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=torch.float64)]
else:
params_t = [torch.tensor([1.5, 1.5])]
params = [Parameter(param_t) for param_t in params_t]
optimizer = constructor(params)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
if not sparse_only:
params_c = [Parameter(param_t.clone()) for param_t in params_t]
optimizer_c = constructor(params_c)
solution = torch.tensor([1, 1])
with torch.no_grad():
initial_dist = sum([param.dist(solution) for param in params])
def get_grad(param, sparse_grad):
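# NB: `w` is deliberately not an argument here; it is read from the enclosing
# scope, where the cyclic coordinate descent loop below reassigns it on every
# iteration to alternate between the x and y gradient components.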
grad = drosenbrock(param)
# NB: We torture test the optimizer by returning an
# uncoalesced sparse tensor
# Depending on w, provide only the x or y gradient
if sparse_grad:
if w:
i = torch.LongTensor([[0, 0]])
x = grad[0]
v = torch.tensor([x / 4.0, x - x / 4.0])
else:
i = torch.LongTensor([[1, 1]])
y = grad[1]
v = torch.tensor([y - y / 4.0, y / 4.0])
grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
else:
if w:
grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
else:
grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
return grad_out
def eval(params, sparse_grad, w):
optimizer.zero_grad()
if multi_tensor:
loss = sum(rosenbrock(param) for param in params)
else:
loss = rosenbrock(params[0])
loss.backward()
grads_out = [get_grad(param, sparse_grad) for param in params]
with torch.no_grad():
params[0].grad = grads_out[0]
if multi_tensor:
params[1].grad = grads_out[1].to(dtype=torch.float64)
return loss
for i in range(2000):
# Do cyclic coordinate descent
w = i % 2
optimizer.step(functools.partial(eval, params, True, w))
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(rosenbrock(params[0]))
else:
scheduler.step()
if not sparse_only:
optimizer_c.step(functools.partial(eval, params_c, False, w))
# Tolerance is increased due to floating point error from different
# code path for dense case: x v.s. x - x / 4.0 + x / 4.0
self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
if not maximize:
self.assertLessEqual(
sum([param.dist(solution) for param in params]),
initial_dist
)
else:
self.assertGreaterEqual(
sum([rosenbrock(param) for param in params]),
sum([rosenbrock(param_t) for param_t in params_t]),
)
def _test_basic_cases_template(
self,
weight_tensor,
bias_tensor,
input_tensor,
constructor,
scheduler_constructors,
constructor_accepts_maximize=True,
constructor_accepts_foreach=False,
):
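# These sets collapse to {False} when the constructor does not accept the flag
# and expand to {False, True} when it does, so the product below exercises
# every supported combination of maximize/foreach.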
maximize_options = {False, constructor_accepts_maximize}
foreach_options = {False, constructor_accepts_foreach}
four_arg_constructor = constructor
if constructor_accepts_maximize and constructor_accepts_foreach:
pass
elif constructor_accepts_maximize:
def four_arg_constructor(weight, bias, maximize, foreach): # noqa: F811
self.assertFalse(foreach)
return constructor(weight, bias, maximize)
elif constructor_accepts_foreach:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize)
return constructor(weight, bias, foreach)
else:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize or foreach)
return constructor(weight, bias)
for maximize, foreach in itertools.product(maximize_options, foreach_options):
with torch.no_grad():
weight = Parameter(weight_tensor.clone().detach())
bias = Parameter(bias_tensor.clone().detach())
input = input_tensor.clone().detach().requires_grad_()
optimizer = four_arg_constructor(weight, bias, maximize, foreach)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _ in range(200):
optimizer.step(fn)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
val_loss = fn()
scheduler.step(val_loss)
else:
scheduler.step()
if maximize:
self.assertGreater(fn().item(), initial_value)
else:
self.assertLess(fn().item(), initial_value)
def _test_basic_cases(
self,
constructor,
scheduler_constructors=None,
ignore_multidevice=False,
constructor_accepts_maximize=False,
constructor_accepts_foreach=False,
atol=None,
rtol=None,
):
if scheduler_constructors is None:
scheduler_constructors = []
def make_two_arg_constructor(
constructor, maximize: bool, foreach: bool
):
if constructor_accepts_maximize and constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, maximize, foreach)
if constructor_accepts_maximize:
return lambda weight, bias: constructor(weight, bias, maximize)
if constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, foreach)
return constructor
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# non-contiguous parameters
self._test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# CUDA
if not torch.cuda.is_available():
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(),
torch.randn(10).cuda(),
torch.randn(5).cuda(),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# Multi-GPU
if not torch.cuda.device_count() > 1 or ignore_multidevice:
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(1),
torch.randn(5).cuda(0),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
def _test_complex_optimizer(self, optimizer_constructor):
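# Step a complex parameter and its view_as_real counterpart with identically
# configured optimizers and gradients that are real views of each other; the
# two should stay in lockstep after every step.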
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
complex_opt = optimizer_constructor(complex_param)
real_opt = optimizer_constructor(real_param)
for _ in range(3):
complex_param.grad = torch.randn_like(complex_param)
real_param.grad = torch.view_as_real(complex_param.grad)
complex_opt.step()
real_opt.step()
self.assertEqual(torch.view_as_real(complex_param), real_param)
def _test_complex_2d(self, optimizer_constructor):
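# Optimizing a complex tensor directly should be equivalent to optimizing its
# real and imaginary parts as two separate real tensors; both the gradients
# and the updated values are compared on every iteration.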
a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
a1_real = a1.real.clone().detach()
a1_imag = a1.imag.clone().detach()
a1_real.requires_grad_()
a1_imag.requires_grad_()
optim1 = optimizer_constructor([a1])
optim2 = optimizer_constructor([a1_real, a1_imag])
for _ in range(10):
optim1.zero_grad()
optim2.zero_grad()
a2 = torch.complex(a1_real, a1_imag)
rosenbrock(a1).abs().backward()
rosenbrock(a2).abs().backward()
self.assertEqual(a1.grad.real, a1_real.grad)
self.assertEqual(a1.grad.imag, a1_imag.grad)
optim1.step()
optim2.step()
self.assertEqual(a1.real, a1_real)
self.assertEqual(a1.imag, a1_imag)
def _build_params_dict(self, weight, bias, **kwargs):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
nesterov=True,
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_sgd_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
lambda params: SGD(params, lr=4.8e-3, foreach=foreach),
multi_tensor=foreach,
)
self._test_rosenbrock_sparse(
lambda params: SGD(params, lr=0.0048, foreach=foreach),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
multi_tensor=foreach,
)
def test_sgd_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: SGD([param], lr=0.001, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: SGD(
[param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
)
)
self._test_complex_optimizer(
lambda param: SGD(
[param],
lr=0.001,
nesterov=True,
momentum=1,
weight_decay=1,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: SGD(
[param],
lr=0.001,
momentum=1,
dampening=0.5,
weight_decay=1,
foreach=foreach,
)
)
def test_adam(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=torch.tensor(1e-3),
maximize=maximize,
foreach=False, # foreach for lr tensors tested in multi configs
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adam_complex(self):
for foreach in (False, True):
self._test_complex_2d(functools.partial(Adam, foreach=foreach))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, amsgrad=True))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, weight_decay=0.2))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, weight_decay=0.2, amsgrad=True))
self._test_complex_2d(Adam)
self._test_complex_2d(functools.partial(
Adam, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True,
))
def test_adamw(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=1e-3,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=1e-3,
weight_decay=1,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=torch.tensor(1e-3),
weight_decay=1,
amsgrad=True,
maximize=maximize,
foreach=False, # foreach for lr tensors tested in multi configs
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adamw_complex(self):
self._test_complex_2d(AdamW)
self._test_complex_2d(functools.partial(
AdamW, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True,
))
for foreach in (False, True):
self._test_complex_2d(functools.partial(AdamW, foreach=foreach))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, amsgrad=True))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, weight_decay=0.2))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, weight_decay=0.2, amsgrad=True))
def test_sparse_adam(self):
self._test_rosenbrock_sparse(
lambda params: SparseAdam(params, lr=4e-2), [], True
)
self._test_rosenbrock_sparse(
lambda params: SparseAdam(params, lr=4e-2, maximize=True),
scheduler_constructors=[],
sparse_only=True,
maximize=True,
)
# ROCm precision is too low to pass this test
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
[weight, bias], maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
[weight, bias], weight_decay=1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adadelta_complex(self):
# Handles https://github.com/pytorch/pytorch/issues/110606
self.rel_tol = 2e-2
for foreach in (False, True):
self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach))
self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach))
self._test_complex_optimizer(
lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach)
)
def test_nadam(self):
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
foreach=foreach,
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
# NAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
decoupled_weight_decay=True,
foreach=foreach,
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
decoupled_weight_decay=True,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
def test_nadam_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: NAdam([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: NAdam(
[param],
lr=1e-1,
weight_decay=0.01,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: NAdam(
[param],
lr=1e-1,
momentum_decay=0.01,
foreach=foreach,
)
)
def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
[weight, bias],
lr=1e-1,
initial_accumulator_value=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ReduceLROnPlateau(opt),
lambda opt: ExponentialLR(opt, gamma=0.99),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adagrad_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
lambda params: Adagrad(params, lr=1e-1, foreach=foreach),
multi_tensor=foreach,
)
self._test_rosenbrock_sparse(
lambda params: Adagrad(params, lr=0.1, foreach=foreach),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
],
multi_tensor=foreach,
)
def test_adagrad_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: Adagrad([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: Adagrad(
[param],
lr=1e-1,
initial_accumulator_value=0.1,
foreach=foreach,
)
)
def test_adamax(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
[weight, bias],
lr=1e-1,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(Adamax)
self._test_complex_2d(functools.partial(Adamax, foreach=True))
def test_radam(self):
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
# RAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
def test_radam_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: RAdam([param], lr=1e-1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RAdam(
[param],
lr=1e-1,
weight_decay=0.01,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: RAdam(
[param],
lr=1e-1,
weight_decay=0.01,
decoupled_weight_decay=True,
foreach=foreach,
)
)
def test_rmsprop(self):
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
[weight, bias], lr=1e-2, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
self._test_complex_2d(
lambda param: RMSprop(param, centered=True, foreach=foreach)
)
self._test_complex_2d(
lambda param: RMSprop(param, momentum=0.1, foreach=foreach)
)
self._test_complex_2d(
lambda param: RMSprop(param, maximize=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], centered=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], momentum=0.1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: RMSprop([param], maximize=True, foreach=foreach)
)
def test_asgd(self):
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: ASGD(
[weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
# Ref: https://github.com/pytorch/pytorch/issues/84560
# self._test_complex_2d(optimizer)
self._test_complex_optimizer(
lambda params: ASGD([params], foreach=foreach)
)
self._test_complex_optimizer(
lambda params: ASGD([params], maximize=True, foreach=foreach)
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=True, weight_decay=0.9, foreach=foreach
)
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=False, weight_decay=0.9, foreach=foreach
)
)
@skipIfRocm
@skipIfTorchDynamo()
def test_rprop(self):
is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
0
) == (8, 6)
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Rprop(
[weight, bias], lr=2e-4, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
self._test_complex_optimizer(
lambda param: Rprop([param], lr=0.001, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: Rprop(
[param], lr=0.001, maximize=True, foreach=foreach
)
)
def test_lbfgs(self):
self._test_basic_cases(
lambda weight, bias: LBFGS([weight, bias]), ignore_multidevice=True
)
self._test_basic_cases(
lambda weight, bias: LBFGS(
[weight, bias], line_search_fn="strong_wolfe"
),
ignore_multidevice=True,
)
def test_lbfgs_returns_consistent_type(self):
params = [torch.randn(10, 5), torch.randn(10)]
opt1 = LBFGS(params, 0.01, tolerance_grad=math.inf)
opt2 = LBFGS(params, 0.01, tolerance_grad=-math.inf)
def closure():
return torch.tensor([10])
res1 = opt1.step(closure)
res2 = opt2.step(closure)
self.assertEqual(type(res1), type(res2))
def test_fused_optimizer_does_not_step_if_foundinf(self):
if not torch.cuda.is_available():
self.skipTest("CUDA is required.")
from torch.optim import adam, adamw, sgd
num_tensors = 5
for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw, sgd.sgd), (False, True), (False, True)):
if functional_optim in (adam.adam, adamw.adamw):
params, grads, exp_avgs, exp_avg_sqs = (
[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4))
prev_params = [t.clone().detach() for t in params]
max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
found_inf = torch.ones((), dtype=torch.float32, device="cuda")
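# found_inf is nonzero, so the fused kernel must skip the update entirely:
# both params and state_steps should come out of the call unchanged.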
functional_optim(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
foreach=False,
capturable=False,
fused=True,
amsgrad=amsgrad,
beta1=0.9,
beta2=0.99,
lr=1e-2,
weight_decay=0.0,
eps=1e-8,
maximize=False,
grad_scale=grad_scale,
found_inf=found_inf,
)
self.assertEqual(
state_steps,
[
torch.ones((), dtype=torch.float32, device="cuda")
for _ in range(num_tensors)
],
)
self.assertEqual(params, prev_params)
else:
for momentum in (0.0, 0.1):
params, d_p_list, momentum_buffer_list = (
[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(3))
if momentum == 0.0:
momentum_buffer_list = [None for _ in range(num_tensors)]
prev_params = [t.clone().detach() for t in params]
grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
found_inf = torch.ones((), dtype=torch.float32, device="cuda")
sgd.sgd(
params,
d_p_list,
momentum_buffer_list,
has_sparse_grad=False,
foreach=False,
fused=True,
grad_scale=grad_scale,
found_inf=found_inf,
weight_decay=0.0,
momentum=momentum,
lr=0.01,
dampening=0.0,
nesterov=False,
maximize=False,
)
@skipIfTorchDynamo()
def test_post_hook(self):
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data += 2
params = [torch.Tensor([1, 1])]
opt = SGD(params, lr=0.001)
data = 2
hook_handle = opt.register_step_post_hook(post_hook)
opt.step()
opt.step()
# check if post hooks were registered
self.assertEqual(data, 6)
# remove handles, take step and verify that hook is no longer registered
hook_handle.remove()
opt.step()
self.assertEqual(data, 6)
@skipIfTorchDynamo()
def test_pre_hook(self):
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data += 2
params = [torch.Tensor([1, 1])]
opt = SGD(params, lr=0.001)
data = 5
hook_handle = opt.register_step_pre_hook(pre_hook)
opt.step()
opt.step()
# check if pre hooks were registered
self.assertEqual(data, 9)
# remove handles, take step and verify that hook is no longer registered
hook_handle.remove()
opt.step()
self.assertEqual(data, 9)
@skipIfTorchDynamo()
def test_pre_and_post_hook(self):
def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(0)
def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(5)
def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(1)
def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(2)
params = [torch.Tensor([1, 1])]
opt1 = SGD(params, lr=0.001)
opt2 = Adam(params, lr=0.01)
data = []
# register global hooks to both optimizers
global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
global_post_handle = register_optimizer_step_post_hook(global_post_hook)
# register local hooks
first_pre_handle = opt1.register_step_pre_hook(local_pre_hook)
first_post_handle = opt1.register_step_post_hook(local_post_hook)
second_pre_handle = opt2.register_step_pre_hook(local_pre_hook)
second_post_handle = opt2.register_step_post_hook(local_post_hook)
opt1.step()
self.assertListEqual(data, [0, 1, 2, 5])
opt2.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
opt1.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
# remove all hooks
global_pre_handle.remove()
global_post_handle.remove()
first_pre_handle.remove()
first_post_handle.remove()
second_pre_handle.remove()
second_post_handle.remove()
opt1.step()
opt2.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
optimizer.state["test"] = 1
@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
else:
state_dict["ran_state_dict_pre_hook"] = False
return state_dict
@staticmethod
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
state_dict["param_groups"][0]["lr"] = 0.002
@staticmethod
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
# The typical use case for returning a state dict is to drastically modify the state dict.
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
my_state_dict = deepcopy(state_dict)
my_state_dict["param_groups"][0]["lr"] = 0.003
return my_state_dict
@staticmethod
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
optimizer.state["ran_load_state_dict_post_hook"] = True
def test_state_dict_pre_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["state"]["test"], 1)
def test_state_dict_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)
def test_state_dict_pre_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertFalse("test" in state_dict["state"])
self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)
def test_load_state_dict_pre_hook_and_prepend(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
state_dict = opt.state_dict()
# usually one would have a new opt instance here, but it's all the same here
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
opt.load_state_dict(state_dict)
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
opt.load_state_dict(state_dict)
# If prepend were False, the lr would end up as 0.003; since prepend is True, the other hook runs last and overrides it back to 0.002
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
def test_load_state_dict_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
opt.load_state_dict(opt.state_dict())
self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
def test_load_state_dict_pre_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
opt.load_state_dict(opt.state_dict())
self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
# `ignored` is the list of values in `opt_differentiable_state`; we pass them
# explicitly so that `gradcheck` tracks the state tensors as function inputs,
# because otherwise it cannot unpack the values inside the
# `opt_differentiable_state` dict
p = p.clone()
p.grad = grad
opt_differentiable_state = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in opt_differentiable_state.items()
}
opt = opt_class([p], **kwargs)
opt.state[p].update(opt_differentiable_state)
opt.step()
return (p,) + tuple(
v
for v in opt.state[p].values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
def test_sgd(self):
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
state = {"momentum_buffer": mbuff}
gradcheck(
_diff_fn,
(
p,
grad,
state,
SGD,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_adam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adam,
{"lr": 0.9, "differentiable": True, "amsgrad": True},
*state.values(),
),
)
def test_rmsprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["step"] = 0
state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["momentum_buffer"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# This can cause issues with large values and nan due to sqrt ops
state["grad_avg"] = 1e-2 * torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RMSprop,
{
"lr": 0.9,
"maximize": True,
"momentum": 0.9,
"differentiable": True,
"centered": True,
"weight_decay": 0.1,
},
*state.values(),
),
)
def test_adadelta(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adadelta,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
def test_adagrad(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adagrad,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
def test_adamax(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adamax,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
@skipIfTorchDynamo("The inplace mu update fails with dynamo, "
"since this is only happening when differentiable is enabled, skipping for now")
def test_asgd(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` `eta` & `mu` are not continuous variables (even though we define them as floats)
# and so they shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
ASGD,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_rprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Rprop,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_adamw(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
AdamW,
{"lr": 0.9, "differentiable": True, "amsgrad": True},
*state.values(),
),
)
def test_nadam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
NAdam,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
NAdam,
{"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
*state.values(),
),
)
def test_radam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RAdam,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RAdam,
{"lr": 0.9, "weight_decay": 0.1, "decoupled_weight_decay": True, "differentiable": True},
*state.values(),
),
)
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_defaults_changed_to_foreach(self):
from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
asgd, adamax, adadelta, adagrad)
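# With params on CUDA and foreach left unspecified, each optimizer should
# dispatch to its multi-tensor (foreach) implementation; the corresponding
# _multi_tensor_* function is patched below and asserted to have been called.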
multi_optims = ((Adam, adam, "_multi_tensor_adam"),
(AdamW, adamw, "_multi_tensor_adamw"),
(NAdam, nadam, "_multi_tensor_nadam"),
(SGD, sgd, "_multi_tensor_sgd"),
(RAdam, radam, "_multi_tensor_radam"),
(RMSprop, rmsprop, "_multi_tensor_rmsprop"),
(Rprop, rprop, "_multi_tensor_rprop"),
(ASGD, asgd, "_multi_tensor_asgd"),
(Adamax, adamax, "_multi_tensor_adamax"),
(Adadelta, adadelta, "_multi_tensor_adadelta"),
(Adagrad, adagrad, "_multi_tensor_adagrad"),)
model = torch.nn.Linear(5, 5)
model.to(dtype=torch.float64, device="cuda")
input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
for opt, mod, func in multi_optims:
defaults = {}
if opt == SGD:
defaults["lr"] = 1e-2
optimizer = opt(model.parameters(), **defaults)
optimizer.zero_grad()
output = model(input)
loss = output.sum()
loss.backward()
with patch.object(mod, func) as mocked_foreach_impl:
optimizer.step()
self.assertTrue(mocked_foreach_impl.called)
if __name__ == "__main__":
print("These tests should be run through test/test_optim.py instead")