Move LRScheduler integration tests to OptimizerInfo (#123134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123134
Approved by: https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: 12e36dc1df
Commit: cb8fc30e4a
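The diff below moves the LRScheduler integration coverage out of the legacy TestOptim helpers and into OptimizerInfo-driven tests (TestOptimRenewed). As a reading aid, here is a minimal, self-contained sketch of the loop those tests run; it is not code from this PR, and the SGD/StepLR/ReduceLROnPlateau choices and hyperparameters are illustrative only.

# Standalone sketch of the scheduler-aware training loop used by the tests below.
# Illustrative only: optimizer, schedulers, and hyperparameters are arbitrary choices.
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

weight = Parameter(torch.randn(10, 5))
bias = Parameter(torch.randn(10))
inpt = torch.randn(5)

optimizer = SGD([weight, bias], lr=1e-3)
# One "scheduler_inputs" entry: a list of constructors, each taking the optimizer.
scheduler_constructors = [
    lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    lambda opt: ReduceLROnPlateau(opt),
]
schedulers = [c(optimizer) for c in scheduler_constructors]

def closure():
    optimizer.zero_grad()
    loss = (weight.mv(inpt) + bias).pow(2).sum()
    loss.backward()
    return loss

initial_value = closure().item()
for _ in range(20):
    loss = optimizer.step(closure)  # step() returns the closure's loss
    for scheduler in schedulers:
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)  # plateau scheduler consumes the metric
        else:
            scheduler.step()

assert closure().item() < initial_value  # the loss should have gone down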
@@ -8,14 +8,7 @@ from torch.nn import Parameter
from torch.optim import (
    Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
)
from torch.optim.lr_scheduler import (
    StepLR,
    ConstantLR,
    LinearLR,
    ExponentialLR,
    ReduceLROnPlateau,
    PolynomialLR,
)
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
@@ -145,7 +138,6 @@ class TestOptim(TestCase):
        bias_tensor,
        input_tensor,
        constructor,
        scheduler_constructors,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=False,
    ):
@@ -179,10 +171,6 @@ class TestOptim(TestCase):
        bias = Parameter(bias_tensor.clone().detach())
        input = input_tensor.clone().detach().requires_grad_()
        optimizer = four_arg_constructor(weight, bias, maximize, foreach)
        schedulers = []
        for scheduler_constructor in scheduler_constructors:
            schedulers.append(scheduler_constructor(optimizer))

        def fn():
            optimizer.zero_grad()
@@ -196,12 +184,6 @@ class TestOptim(TestCase):
        initial_value = fn().item()
        for _ in range(200):
            optimizer.step(fn)
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    val_loss = fn()
                    scheduler.step(val_loss)
                else:
                    scheduler.step()
        if maximize:
            self.assertGreater(fn().item(), initial_value)
        else:
@@ -211,19 +193,14 @@ class TestOptim(TestCase):
    def _test_basic_cases(
        self,
        constructor,
        scheduler_constructors=None,
        constructor_accepts_maximize=False,
        constructor_accepts_foreach=False,
    ):
        if scheduler_constructors is None:
            scheduler_constructors = []

        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
@@ -233,7 +210,6 @@ class TestOptim(TestCase):
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
@@ -245,7 +221,6 @@ class TestOptim(TestCase):
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
@@ -257,90 +232,11 @@ class TestOptim(TestCase):
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor,
            scheduler_constructors,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )

    def _build_params_dict(self, weight, bias, **kwargs):
        return [{"params": [weight]}, dict(params=[bias], **kwargs)]

    def test_sgd(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            scheduler_constructors=[
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
                )
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            scheduler_constructors=[
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
                ),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: SGD(
                [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
            ),
            [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_sgd_sparse(self):
        for foreach in (False, True):
            self._test_rosenbrock_sparse(
@@ -354,110 +250,6 @@ class TestOptim(TestCase):
            )

    def test_adam(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3,
                maximize=maximize,
                foreach=foreach,
            ),
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3,
                maximize=maximize,
                foreach=foreach,
            ),
            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3,
                maximize=maximize,
                foreach=foreach,
            ),
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                [weight, bias],
                lr=1e-3,
                amsgrad=True,
                maximize=maximize,
                foreach=foreach,
            ),
            [
                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
                lambda opt: ExponentialLR(opt, gamma=0.9),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                [weight, bias],
                lr=1e-3,
                amsgrad=True,
                maximize=maximize,
                foreach=foreach,
            ),
            [
                lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3,
                amsgrad=True,
                maximize=maximize,
                foreach=foreach,
            ),
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-3,
                maximize=maximize,
                foreach=foreach,
            ),
            [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adam(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=torch.tensor(1e-3),
                maximize=maximize,
                foreach=False,  # foreach for lr tensors tested in multi configs
            ),
            [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_adamw(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: AdamW(
@@ -484,80 +276,6 @@ class TestOptim(TestCase):
            maximize=True,
        )

    # ROCm precision is too low to pass this test
    def test_adadelta(self):
        # Handles https://github.com/pytorch/pytorch/issues/69698
        self.rel_tol = 4e-3
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adadelta(
                self._build_params_dict(weight, bias, rho=0.95),
                maximize=maximize,
                foreach=foreach,
            ),
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_nadam(self):
        self._test_basic_cases(
            lambda weight, bias, foreach: NAdam(
                [weight, bias],
                lr=1e-3,
                weight_decay=0.1,
                momentum_decay=6e-3,
                foreach=foreach,
            ),
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_foreach=True,
        )
        # NAdamW tests
        self._test_basic_cases(
            lambda weight, bias, foreach: NAdam(
                [weight, bias],
                lr=1e-3,
                weight_decay=0.1,
                momentum_decay=6e-3,
                decoupled_weight_decay=True,
                foreach=foreach,
            ),
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            constructor_accepts_foreach=True,
        )

    def test_adagrad(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adagrad(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-1,
                maximize=maximize,
                foreach=foreach,
            ),
            [lambda opt: ReduceLROnPlateau(opt)],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: Adagrad(
                self._build_params_dict(weight, bias, lr=1e-2),
                lr=1e-1,
                maximize=maximize,
                foreach=foreach,
            ),
            [
                lambda opt: ReduceLROnPlateau(opt),
                lambda opt: ExponentialLR(opt, gamma=0.99),
            ],
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_adagrad_sparse(self):
        for foreach in (False, True):
@@ -575,30 +293,6 @@ class TestOptim(TestCase):
            )

    def test_radam(self):
        self._test_basic_cases(
            lambda weight, bias, foreach: RAdam(
                [weight, bias], lr=1e-3, foreach=foreach
            ),
            [
                lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_foreach=True,
        )
        # RAdamW tests
        self._test_basic_cases(
            lambda weight, bias, foreach: RAdam(
                [weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
            ),
            [
                lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            constructor_accepts_foreach=True,
        )


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # Ignored is the list of values in `opt_differentiable_state`, we do this
    # for `gradcheck` to correctly track the state tensors as function inputs
@@ -1,5 +1,6 @@
# Owner(s): ["module: optimizer"]
import functools
import itertools
import math
import tempfile
from typing import Any, Dict, Tuple
@@ -7,6 +8,8 @@ import unittest
from copy import deepcopy
from unittest.mock import patch

from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch
from torch.optim import Optimizer, SGD
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
@@ -30,6 +33,7 @@ def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


@markDynamoStrictTest
class TestOptimRenewed(TestCase):
@@ -69,11 +73,13 @@ class TestOptimRenewed(TestCase):
    @parametrize("contiguous", [True, False])
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous):
    def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous, with_lrsched):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
        schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
        for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
            if "foreach" in optim_info.supported_impls:
                optim_input.kwargs["foreach"] = False  # force forloop
            if contiguous:
@@ -85,6 +91,7 @@ class TestOptimRenewed(TestCase):
            input = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls([weight, bias], **optim_input.kwargs)
            schedulers = [s(optimizer) for s in (schedulers_constructor if schedulers_constructor else [])]

            def closure():
                optimizer.zero_grad()
@@ -99,7 +106,12 @@ class TestOptimRenewed(TestCase):
            initial_value = closure().item()
            for _ in range(20):
                optimizer.step(closure)
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(closure().item(), initial_value)
@@ -109,22 +121,26 @@ class TestOptimRenewed(TestCase):
    @onlyCUDA
    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @parametrize("with_lrsched", [True, False])
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
    def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info, with_lrsched):
        optim_cls = optim_info.optim_cls
        optim_inputs = optim_info.optim_inputs_func(device=device)
        for optim_input in optim_inputs:
        schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
        for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
            if "foreach" in optim_info.supported_impls:
                optim_input.kwargs["foreach"] = False  # force forloop

            weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
            bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
            input = torch.randn(5, device="cuda:0", dtype=dtype)
            inpt = torch.randn(5, device="cuda:0", dtype=dtype)

            optimizer = optim_cls([weight, bias], **optim_input.kwargs)
            schedulers = [s(optimizer) for s in (schedulers_constructor if schedulers_constructor else [])]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(input).cuda(1) + bias).pow(2).sum()
                loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
@@ -135,7 +151,12 @@ class TestOptimRenewed(TestCase):
            initial_value = closure().item()
            for _ in range(20):
                optimizer.step(closure)
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(closure().item(), initial_value)
@@ -143,6 +164,41 @@ class TestOptimRenewed(TestCase):
                self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_group_with_lrscheduler_goes_right_direction(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        for schedulers_c in optim_info.scheduler_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            inpt = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": 0.01}])
            schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss)
                    else:
                        scheduler.step()

            self.assertLess(closure().item(), initial_value)

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
@@ -26,6 +26,14 @@ from torch.optim import (
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ConstantLR,
    ExponentialLR,
    LinearLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
@@ -101,6 +109,17 @@ class OptimizerInfo:
        # to the test using the OptimizerInfo. OptimizerInput.params is likely None.
        # Can optionally take in device to filter out certain unsupported configs
        optim_inputs_func,
        # Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
        # LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
        # We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
        # LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
        # A few optimizers like SGD and Adam will test more LRSchedulers.
        scheduler_inputs=(
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
        supported_impls: Tuple[str] = ("foreach", "differentiable"),
@@ -122,6 +141,7 @@ class OptimizerInfo:
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
        self.scheduler_inputs = scheduler_inputs
        self.supported_impls = supported_impls
        self.supports_sparse_on = supports_sparse_on
        self.only_supports_sparse_grads = only_supports_sparse_grads
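The hunks that follow populate the new scheduler_inputs field for Adam and SGD. For orientation, here is a small, self-contained sketch of how that field's shape is consumed: a tuple of scheduler configurations, each a list of constructors receiving the freshly built optimizer, crossed with the optimizer's input configurations. It is illustrative only; the dictionaries stand in for what optim_inputs_func produces, and the scheduler choices are arbitrary.

# Illustrative sketch (not code from this PR) of consuming a scheduler_inputs tuple.
import itertools

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Stand-ins: the real tests get these from optim_info.optim_inputs_func(...) and
# optim_info.scheduler_inputs respectively.
optim_inputs = [dict(lr=1e-3), dict(lr=1e-3, momentum=0.9)]
scheduler_inputs = (
    [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
    [lambda opt: StepLR(opt, gamma=0.9, step_size=10), lambda opt: ReduceLROnPlateau(opt)],
)

param = torch.nn.Parameter(torch.randn(3))
for kwargs, scheduler_constructors in itertools.product(optim_inputs, scheduler_inputs):
    optimizer = SGD([param], **kwargs)
    schedulers = [c(optimizer) for c in scheduler_constructors]
    # ...train, calling scheduler.step() (or scheduler.step(loss) for ReduceLROnPlateau)
    # once per optimizer.step(), as in the tests above.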
@@ -1184,6 +1204,24 @@ optim_db: List[OptimizerInfo] = [
    OptimizerInfo(
        Adam,
        optim_inputs_func=optim_inputs_func_adam,
        scheduler_inputs=(
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
            [
                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
                lambda opt: ExponentialLR(opt, gamma=0.9),
            ],
            [
                lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
        skips=(
@@ -1200,6 +1238,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "No closure handling, https://github.com/pytorch/pytorch/issues/116494"
@@ -1571,6 +1614,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1935,6 +1983,31 @@ optim_db: List[OptimizerInfo] = [
    OptimizerInfo(
        SGD,
        optim_inputs_func=optim_inputs_func_sgd,
        scheduler_inputs=(
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
            [
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
                )
            ],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
                ),
            ],
            [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        optim_error_inputs_func=optim_error_inputs_func_sgd,
        supported_impls=("foreach", "differentiable", "fused"),
        supports_sparse_on=("cpu", "cuda"),
@@ -1953,6 +2026,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -2092,6 +2170,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",