Move LRScheduler integration tests to OptimizerInfo (#123134)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123134
Approved by: https://github.com/albanD
Author: Jane Xu
Date: 2024-04-02 12:20:55 -07:00
Committed by: PyTorch MergeBot
parent 12e36dc1df
commit cb8fc30e4a
3 changed files with 148 additions and 315 deletions
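For orientation before the diffs: OptimizerInfo gains a scheduler_inputs field, a tuple of scheduler "chains" where each chain is a list of constructor lambdas that the scheduler-aware tests instantiate against a freshly built optimizer. The sketch below is illustrative only (an abbreviated, hypothetical SGD-style value rather than the real optim_db entry), but the chain shapes mirror the ones added in this commit:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Each tuple element is one scheduler chain to exercise with the optimizer;
# a chain may combine several schedulers that are stepped in order.
example_scheduler_inputs = (
    [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
    [
        lambda opt: StepLR(opt, gamma=0.9, step_size=10),
        lambda opt: ReduceLROnPlateau(opt),
    ],
)

# A chain is materialized against whatever optimizer a test just built:
opt = SGD([torch.nn.Parameter(torch.randn(2, 2))], lr=1e-3)
schedulers = [ctor(opt) for ctor in example_scheduler_inputs[1]]

ReduceLROnPlateau is the one scheduler that needs special handling downstream, because its step() expects a metric; the tests therefore pass it the loss returned by optimizer.step(closure).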

View File

@@ -8,14 +8,7 @@ from torch.nn import Parameter
from torch.optim import (
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
)
from torch.optim.lr_scheduler import (
StepLR,
ConstantLR,
LinearLR,
ExponentialLR,
ReduceLROnPlateau,
PolynomialLR,
)
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
@@ -145,7 +138,6 @@ class TestOptim(TestCase):
bias_tensor,
input_tensor,
constructor,
scheduler_constructors,
constructor_accepts_maximize=True,
constructor_accepts_foreach=False,
):
@@ -179,10 +171,6 @@
bias = Parameter(bias_tensor.clone().detach())
input = input_tensor.clone().detach().requires_grad_()
optimizer = four_arg_constructor(weight, bias, maximize, foreach)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
def fn():
optimizer.zero_grad()
@@ -196,12 +184,6 @@
initial_value = fn().item()
for _ in range(200):
optimizer.step(fn)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
val_loss = fn()
scheduler.step(val_loss)
else:
scheduler.step()
if maximize:
self.assertGreater(fn().item(), initial_value)
else:
@@ -211,19 +193,14 @@
def _test_basic_cases(
self,
constructor,
scheduler_constructors=None,
constructor_accepts_maximize=False,
constructor_accepts_foreach=False,
):
if scheduler_constructors is None:
scheduler_constructors = []
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
@@ -233,7 +210,6 @@
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
@@ -245,7 +221,6 @@
torch.randn(10).cuda(),
torch.randn(5).cuda(),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
@@ -257,90 +232,11 @@
torch.randn(10).cuda(1),
torch.randn(5).cuda(0),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
def _build_params_dict(self, weight, bias, **kwargs):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_sgd_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
@@ -354,110 +250,6 @@ class TestOptim(TestCase):
)
def test_adam(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=torch.tensor(1e-3),
maximize=maximize,
foreach=False, # foreach for lr tensors tested in multi configs
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adamw(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
@@ -484,80 +276,6 @@ class TestOptim(TestCase):
maximize=True,
)
# ROCm precision is too low to pass this test
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_nadam(self):
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
# NAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
decoupled_weight_decay=True,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ReduceLROnPlateau(opt),
lambda opt: ExponentialLR(opt, gamma=0.99),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adagrad_sparse(self):
for foreach in (False, True):
@@ -575,30 +293,6 @@
)
def test_radam(self):
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
# RAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
# Ignored is the list of values in `opt_differentiable_state`, we do this
# for `gradcheck` to correctly track the state tensors as function inputs

View File

@@ -1,5 +1,6 @@
# Owner(s): ["module: optimizer"]
import functools
import itertools
import math
import tempfile
from typing import Any, Dict, Tuple
@@ -7,6 +8,8 @@ import unittest
from copy import deepcopy
from unittest.mock import patch
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
from torch.optim import Optimizer, SGD
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
@@ -30,6 +33,7 @@ def rosenbrock(tensor):
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
@markDynamoStrictTest
class TestOptimRenewed(TestCase):
@@ -69,11 +73,13 @@ class TestOptimRenewed(TestCase):
@parametrize("contiguous", [True, False])
@parametrize("with_lrsched", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous):
def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous, with_lrsched):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
if contiguous:
@@ -85,6 +91,7 @@ class TestOptimRenewed(TestCase):
input = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
schedulers = [s(optimizer) for s in (schedulers_constructor if schedulers_constructor else [])]
def closure():
optimizer.zero_grad()
@@ -99,7 +106,12 @@
initial_value = closure().item()
for _ in range(20):
optimizer.step(closure)
loss = optimizer.step(closure)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
else:
scheduler.step()
if optim_input.kwargs.get("maximize", False):
self.assertGreater(closure().item(), initial_value)
@@ -109,22 +121,26 @@
@onlyCUDA
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@parametrize("with_lrsched", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info, with_lrsched):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
schedulers_constructors = optim_info.scheduler_inputs if with_lrsched else [None]
for optim_input, schedulers_constructor in itertools.product(optim_inputs, schedulers_constructors):
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
input = torch.randn(5, device="cuda:0", dtype=dtype)
inpt = torch.randn(5, device="cuda:0", dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
schedulers = [s(optimizer) for s in (schedulers_constructor if schedulers_constructor else [])]
def closure():
optimizer.zero_grad()
loss = (weight.mv(input).cuda(1) + bias).pow(2).sum()
loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
@@ -135,7 +151,12 @@
initial_value = closure().item()
for _ in range(20):
optimizer.step(closure)
loss = optimizer.step(closure)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
else:
scheduler.step()
if optim_input.kwargs.get("maximize", False):
self.assertGreater(closure().item(), initial_value)
@@ -143,6 +164,41 @@
self.assertLess(closure().item(), initial_value)
@optims(optim_db, dtypes=[torch.float32])
def test_param_group_with_lrscheduler_goes_right_direction(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
for schedulers_c in optim_info.scheduler_inputs:
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
inpt = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": 0.01}])
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
def closure():
optimizer.zero_grad()
loss = (weight.mv(inpt) + bias).pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
# NOT represent the expected use case for optims like SparseAdam!
weight.grad = weight.grad.to_sparse()
bias.grad = bias.grad.to_sparse()
return loss
initial_value = closure().item()
for _ in range(20):
loss = optimizer.step(closure)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(loss)
else:
scheduler.step()
self.assertLess(closure().item(), initial_value)
@skipMPS
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
def test_complex(self, device, dtype, optim_info):
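
Putting the pieces of the test file above together, the scheduler-aware loop that test_forloop_goes_right_direction, its multigpu variant, and the new test_param_group_with_lrscheduler_goes_right_direction now share looks roughly like the standalone sketch below. It is a condensed re-creation, not the test itself: device/dtype/foreach parametrization, the sparse-grad conversion, and the maximize branch are omitted, and the SGD/StepLR/ReduceLROnPlateau chain is just one entry of the kind carried in scheduler_inputs:

import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

scheduler_ctors = [
    lambda opt: StepLR(opt, gamma=0.9, step_size=10),
    lambda opt: ReduceLROnPlateau(opt),
]

weight = Parameter(torch.randn(10, 5))
bias = Parameter(torch.randn(10))
inpt = torch.randn(5)
optimizer = SGD([weight, bias], lr=1e-3)
# Build the whole scheduler chain against this optimizer instance.
schedulers = [ctor(optimizer) for ctor in scheduler_ctors]

def closure():
    optimizer.zero_grad()
    loss = (weight.mv(inpt) + bias).pow(2).sum()
    loss.backward()
    return loss

initial_value = closure().item()
for _ in range(20):
    loss = optimizer.step(closure)  # step() with a closure returns its loss
    for scheduler in schedulers:
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)  # plateau scheduling consumes the metric
        else:
            scheduler.step()
assert closure().item() < initial_value  # loss should still go down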

View File

@@ -26,6 +26,14 @@ from torch.optim import (
SGD,
SparseAdam,
)
from torch.optim.lr_scheduler import (
ConstantLR,
ExponentialLR,
LinearLR,
PolynomialLR,
ReduceLROnPlateau,
StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
@@ -101,6 +109,17 @@ class OptimizerInfo:
# to the test using the OptimizerInfo. OptimizerInput.params is likely None.
# Can optionally take in device to filter out certain unsupported configs
optim_inputs_func,
# Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
# LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
# We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
# LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
# A few optimizers like SGD and Adam will test more LRSchedulers.
scheduler_inputs=(
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
# A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
# supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
supported_impls: Tuple[str] = ("foreach", "differentiable"),
@@ -122,6 +141,7 @@
):
self.optim_cls = optim_cls
self.optim_inputs_func = optim_inputs_func
self.scheduler_inputs = scheduler_inputs
self.supported_impls = supported_impls
self.supports_sparse_on = supports_sparse_on
self.only_supports_sparse_grads = only_supports_sparse_grads
@@ -1184,6 +1204,24 @@ optim_db: List[OptimizerInfo] = [
OptimizerInfo(
Adam,
optim_inputs_func=optim_inputs_func_adam,
scheduler_inputs=(
[lambda opt: ExponentialLR(opt, gamma=0.9)],
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
[
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9),
],
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
skips=(
@@ -1200,6 +1238,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
@@ -1571,6 +1614,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
unittest.skip("Does not support param groups"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
),
),
OptimizerInfo(
@@ -1935,6 +1983,31 @@ optim_db: List[OptimizerInfo] = [
OptimizerInfo(
SGD,
optim_inputs_func=optim_inputs_func_sgd,
scheduler_inputs=(
[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
optim_error_inputs_func=optim_error_inputs_func_sgd,
supported_impls=("foreach", "differentiable", "fused"),
supports_sparse_on=("cpu", "cuda"),
@@ -1953,6 +2026,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo("initial_value is incorrect in dynamo, see #123202"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -2092,6 +2170,11 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",