Migrate forloop directional tests to OptimizerInfo (#117410)
This PR is another step towards modernizing our optimizer tests by tackling the simplest foreach tests. The replaced tests are now removed from `test/optim/test_optim.py`.

**Changes in coverage?** Yes!

- This PR _decreases_ coverage (!!!!) by checking the direction only on the forloop implementations instead of on both the forloop and foreach paths. Why? Checking the forloop alone should be sufficient, since foreach parity is already verified by the `foreach_matches_forloop` test.
- This PR also _increases_ coverage for SparseAdam with contiguous params on CUDA, which was previously forbidden due to an old bug that has since been fixed.

What will it take to fully remove `test_basic_cases`?

- We need to flavor the tests with LRSchedulers.
- We need tests for param groups, which in these tests only differ by lr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117410
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 5b671ce486
Commit: fc30c4d769
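For orientation, the core of the migration is the OptimizerInfo-driven pattern visible in the diff below: a single test body is parametrized over `optim_db` via the `@optims` decorator, forces the forloop path by setting `foreach=False`, and checks that repeated `optimizer.step(closure)` calls move the loss in the right direction. The following is a simplified, self-contained sketch of that pattern, distilled from the diff; it assumes the `optim_db`/`optims`/`OptimizerInput` helpers live in `torch.testing._internal.common_optimizers`, and the class name is made up for illustration. It is not the exact test added by this PR (the real one also parametrizes over contiguous vs. non-contiguous params and carries per-optimizer skips).

```python
# Simplified sketch of the OptimizerInfo-driven directional test (see the diff
# below for the real version). Helper names are assumed to come from
# torch.testing._internal.common_optimizers; TestOptimDirectionSketch is a
# hypothetical class name used only for this example.
import torch
from torch.nn import Parameter
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests


class TestOptimDirectionSketch(TestCase):
    @optims(optim_db, dtypes=[torch.float32])
    def test_forloop_goes_right_direction(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        for optim_input in optim_info.optim_inputs_func(device=device):
            if "foreach" in optim_info.supported_impls:
                optim_input.kwargs["foreach"] = False  # force the forloop path
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10,), device=device, dtype=dtype))
            inp = torch.randn(5, device=device, dtype=dtype)
            optimizer = optim_cls([weight, bias], **optim_input.kwargs)

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inp) + bias).pow(2).sum()
                loss.backward()
                if optim_cls.__name__ == "SparseAdam":
                    # SparseAdam only accepts sparse gradients; convert them
                    # here, even though that is not the intended real usage.
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                optimizer.step(closure)

            # The forloop implementation should move the loss the right way.
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(closure().item(), initial_value)
            else:
                self.assertLess(closure().item(), initial_value)


instantiate_device_type_tests(TestOptimDirectionSketch, globals())

if __name__ == "__main__":
    run_tests()
```

Usage mirrors the CI hunk below, e.g. `time python test/run_test.py --verbose -i test_optim -- -k test_forloop_goes_right_direction_multigpu` for the multi-GPU variant.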
@@ -49,6 +49,7 @@ time python test/run_test.py --verbose -i distributed/tensor/parallel/test_tp_ex

# Other tests
time python test/run_test.py --verbose -i test_cuda_primary_ctx
time python test/run_test.py --verbose -i test_optim -- -k optimizers_with_varying_tensors
time python test/run_test.py --verbose -i test_optim -- -k test_forloop_goes_right_direction_multigpu
time python test/run_test.py --verbose -i test_optim -- -k test_mixed_device_dtype
time python test/run_test.py --verbose -i test_foreach -- -k test_tensors_grouping
assert_git_not_dirty
@@ -296,9 +296,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
Tensor dense_buffer = dense.to(commonDtype);
Tensor values = sparse._values().to(commonDtype);

if (is_same_tensor(r, dense_buffer)) {
TORCH_CHECK(r_.is_contiguous(), "add: CUDA dense-sparse addition with a non-contiguous output tensor does not work; shout if you need it (see https://github.com/pytorch/pytorch/issues/1521 )");
} else {
if (!is_same_tensor(r, dense_buffer)) {
r.resize_as_(dense);
r.copy_(dense_buffer);
}
@@ -220,26 +220,12 @@ class TestOptim(TestCase):
self,
constructor,
scheduler_constructors=None,
ignore_multidevice=False,
constructor_accepts_maximize=False,
constructor_accepts_foreach=False,
atol=None,
rtol=None,
):
if scheduler_constructors is None:
scheduler_constructors = []

def make_two_arg_constructor(
constructor, maximize: bool, foreach: bool
):
if constructor_accepts_maximize and constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, maximize, foreach)
if constructor_accepts_maximize:
return lambda weight, bias: constructor(weight, bias, maximize)
if constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, foreach)
return constructor

self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),

@@ -272,7 +258,7 @@ class TestOptim(TestCase):
constructor_accepts_foreach,
)
# Multi-GPU
if not torch.cuda.device_count() > 1 or ignore_multidevice:
if not torch.cuda.device_count() > 1:
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(0),

@@ -326,13 +312,6 @@ class TestOptim(TestCase):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]

def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
@@ -405,42 +384,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias],
nesterov=True,
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)


def test_sgd_sparse(self):
@@ -491,35 +434,6 @@ class TestOptim(TestCase):


def test_adam(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
@@ -634,36 +548,6 @@ class TestOptim(TestCase):
))

def test_adamw(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=1e-3,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=1e-3,
weight_decay=1,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
@@ -704,13 +588,6 @@ class TestOptim(TestCase):
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
[weight, bias], maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),

@@ -724,13 +601,7 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
[weight, bias], weight_decay=1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)


def test_adadelta_complex(self):
# Handles https://github.com/pytorch/pytorch/issues/110606
@@ -743,22 +614,6 @@ class TestOptim(TestCase):
)

def test_nadam(self):
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
foreach=foreach,
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
@@ -771,17 +626,6 @@ class TestOptim(TestCase):
constructor_accepts_foreach=True,
)
# NAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
decoupled_weight_decay=True,
foreach=foreach,
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
@@ -819,24 +663,6 @@ class TestOptim(TestCase):
)

def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
[weight, bias],
lr=1e-1,
initial_accumulator_value=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
@@ -894,41 +720,11 @@ class TestOptim(TestCase):
)

def test_adamax(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
[weight, bias], lr=1e-1, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
[weight, bias],
lr=1e-1,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(Adamax)
self._test_complex_2d(functools.partial(Adamax, foreach=True))


def test_radam(self):
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
@@ -940,12 +736,6 @@ class TestOptim(TestCase):
constructor_accepts_foreach=True,
)
# RAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
@@ -983,13 +773,6 @@ class TestOptim(TestCase):

def test_rmsprop(self):
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
[weight, bias], lr=1e-2, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
self._test_complex_2d(
lambda param: RMSprop(param, centered=True, foreach=foreach)
@@ -1016,13 +799,6 @@ class TestOptim(TestCase):

def test_asgd(self):
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: ASGD(
[weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
# Ref: https://github.com/pytorch/pytorch/issues/84560
# self._test_complex_2d(optimizer)
self._test_complex_optimizer(
@@ -1033,12 +809,12 @@ class TestOptim(TestCase):
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=True, weight_decay=0.9, foreach=foreach
[params], maximize=True, weight_decay=0.1, foreach=foreach
)
)
self._test_complex_optimizer(
lambda params: ASGD(
[params], maximize=False, weight_decay=0.9, foreach=foreach
[params], maximize=False, weight_decay=0.1, foreach=foreach
)
)
@@ -1046,17 +822,7 @@ class TestOptim(TestCase):
@skipIfRocm
@skipIfTorchDynamo()
def test_rprop(self):
is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(
0
) == (8, 6)
for foreach in (False, True):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Rprop(
[weight, bias], lr=2e-4, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
self._test_complex_optimizer(
lambda param: Rprop([param], lr=0.001, foreach=foreach)
@@ -1068,17 +834,6 @@ class TestOptim(TestCase):
)


def test_lbfgs(self):
self._test_basic_cases(
lambda weight, bias: LBFGS([weight, bias]), ignore_multidevice=True
)
self._test_basic_cases(
lambda weight, bias: LBFGS(
[weight, bias], line_search_fn="strong_wolfe"
),
ignore_multidevice=True,
)

def test_lbfgs_returns_consistent_type(self):
params = [torch.randn(10, 5), torch.randn(10)]
opt1 = LBFGS(params, 0.01, tolerance_grad=math.inf)
@@ -56,6 +56,78 @@ class TestOptimRenewed(TestCase):
raise NotImplementedError(f"Unknown error type {error_input.error_on}")


@parametrize("contiguous", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction(self, device, dtype, optim_info, contiguous):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
if contiguous:
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
else:
weight = Parameter(torch.randn((10, 5, 2), device=device, dtype=dtype)[..., 0])
bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
input = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)

def closure():
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
if optim_cls.__name__ == "SparseAdam":
# SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
# which we know does NOT represent the expected use case!
weight.grad = weight.grad.to_sparse()
bias.grad = bias.grad.to_sparse()
return loss

initial_value = closure().item()
for _ in range(20):
optimizer.step(closure)

if optim_input.kwargs.get("maximize", False):
self.assertGreater(closure().item(), initial_value)
else:
self.assertLess(closure().item(), initial_value)


@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device="cuda")
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
input = torch.randn(5, device="cuda:0", dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)

def closure():
optimizer.zero_grad()
loss = (weight.mv(input).cuda(1) + bias).pow(2).sum()
loss.backward()
if optim_cls.__name__ == "SparseAdam":
# SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
# which we know does NOT represent the expected use case!
weight.grad = weight.grad.to_sparse()
bias.grad = bias.grad.to_sparse()
return loss

initial_value = closure().item()
for _ in range(20):
optimizer.step(closure)

if optim_input.kwargs.get("maximize", False):
self.assertGreater(closure().item(), initial_value)
else:
self.assertLess(closure().item(), initial_value)


def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
"""
Given a flag 'fused' or 'foreach', test for parity of optimizer state
@@ -1,6 +1,5 @@
import functools
import itertools
import math
import unittest
from copy import deepcopy
from enum import Enum

@@ -268,15 +267,15 @@ def optim_inputs_func_adadelta(device=None):
params=None, kwargs={"lr": 0.01}, desc="non-default lr"
), # TODO: Move out to testing in param_group?
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho"
), # TODO: Move out to testing in param_group?
]
@@ -302,21 +301,22 @@ def optim_inputs_func_adagrad(device=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
OptimizerInput(
params=None,
kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.9},
kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
desc="initial_accumulator_value",
),
OptimizerInput(
params=None,
kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.9},
kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
desc="lr_decay",
), # TODO: Move out to testing in param_group?
]
@@ -346,7 +346,7 @@ def optim_inputs_func_adam(device=None):
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "amsgrad": True, "capturable": True},
kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
desc="capturable, amsgrad",
),
OptimizerInput(

@@ -360,15 +360,15 @@ def optim_inputs_func_adam(device=None):
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9, "amsgrad": True}, desc="amsgrad"
params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
),
] + (cuda_supported_configs if str(device) == "cuda" else [])
@@ -452,13 +452,13 @@ def optim_inputs_func_adamax(device=None):

return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.001}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
] + (cuda_supported_configs if str(device) == "cuda" else [])

@@ -495,11 +495,11 @@ def optim_inputs_func_asgd(device=None):
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
]
@@ -527,7 +527,7 @@ def optim_inputs_func_lbfgs(device=None):
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"tolerance_grad": math.inf}, desc="tolerance_grad"
params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
),
OptimizerInput(
params=None,

@@ -556,13 +556,13 @@ def optim_inputs_func_nadam(device=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3},
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.9,
"weight_decay": 0.1,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
},
@@ -604,11 +604,11 @@ def optim_inputs_func_radam(device=None):
OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "decoupled_weight_decay": True},
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
desc="decoupled_weight_decay",
),
]
@@ -645,22 +645,22 @@ def optim_inputs_func_rmsprop(device=None):
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.9}, desc="nonzero weight_decay"
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "centered": True},
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "centered": True, "momentum": 0.1},
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
desc="momentum",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.9,
"weight_decay": 0.1,
"centered": True,
"momentum": 0.1,
"maximize": True,
@@ -722,29 +722,27 @@ def optim_error_inputs_func_rprop(device, dtype):

def optim_inputs_func_sgd(device=None):
return [
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="Tensor lr"),
OptimizerInput(
params=None, kwargs={"lr": 1e-2, "momentum": 0.9}, desc="momentum"
),
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
kwargs={"lr": 1e-2, "momentum": 0.9, "dampening": 0.5},
kwargs={"momentum": 0.9, "dampening": 0.5},
desc="dampening",
),
OptimizerInput(
params=None,
kwargs={"lr": 1e-2, "momentum": 0.9, "weight_decay": 0.9},
kwargs={"momentum": 0.9, "weight_decay": 0.1},
desc="non-zero weight_decay",
),
OptimizerInput(
params=None,
kwargs={"lr": 1e-2, "momentum": 0.9, "nesterov": True, "weight_decay": 0.9},
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
OptimizerInput(
params=None,
kwargs={"lr": 1e-2, "weight_decay": 0.9, "maximize": True},
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
]
@@ -897,6 +895,20 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adadelta,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115679"

@@ -920,7 +932,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",

@@ -948,6 +960,26 @@ optim_db: List[OptimizerInfo] = [
supported_impls=("foreach", "differentiable"),
supports_sparse_on=("cpu"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115607"

@@ -971,7 +1003,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -991,6 +1023,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"

@@ -1014,7 +1066,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",

@@ -1034,6 +1086,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115607"

@@ -1062,7 +1134,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -1082,6 +1154,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"

@@ -1105,7 +1197,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",

@@ -1125,6 +1217,20 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See discrepancy in https://github.com/pytorch/pytorch/issues/115607"

@@ -1157,7 +1263,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -1201,6 +1307,18 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_param_groups_weight_decay",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
unittest.skip("LBFGS doesn't support multidevice"),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
),
),
OptimizerInfo(

@@ -1209,6 +1327,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_nadam,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"

@@ -1225,7 +1363,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -1267,6 +1405,20 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"

@@ -1296,6 +1448,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_rmsprop,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115679"

@@ -1329,7 +1501,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -1356,6 +1528,26 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_rprop,
supported_impls=("foreach", "differentiable"),
skips=(
DecorateInfo(
skipIfMps, # Rprop doesn't update for non-contiguous, see #118117
"TestOptimRenewed",
"test_forloop_goes_right_direction",
active_if=lambda kwargs: not kwargs["contiguous"],
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115679"

@@ -1379,7 +1571,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",
@@ -1407,6 +1599,20 @@ optim_db: List[OptimizerInfo] = [
supported_impls=("foreach", "differentiable", "fused"),
supports_sparse_on=("cpu", "cuda"),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo(
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo(
"Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"

@@ -1441,7 +1647,7 @@ optim_db: List[OptimizerInfo] = [
),
DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/116494"
"No closure handling, https://github.com/pytorch/pytorch/issues/116494"
),
"TestOptimRenewed",
"test_state_dict_deterministic",

@@ -1521,6 +1727,16 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_load_nontensor_step",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",