mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)
This PR addresses issue #123451. Specifically, the `test_graph_optims` and `test_graph_scaling_fused_optimizers` functions in `test_cuda.py` have been updated so that they now use the new OptimizerInfo infrastructure.

Lintrunner passed:

```
$ lintrunner test/test_cuda.py
ok No lint issues.
```

Tests passed:

```
> python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s
OK (skipped=9)

> python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s
OK (skipped=3)
```

Both functions have been moved to the newly created TestCase class `TestCudaOptims`. The tests are mostly unchanged, except that the `@optims` decorator at the top of each function now implicitly invokes the test once per optimizer listed in the decorator, instead of explicitly iterating over the optimizers with a for loop.

I was unable to use `_get_optim_inputs_including_global_cliquey_kwargs` to get all kwargs for each optimizer, since some of the kwargs used in the original `test_graph_optims` function are not returned by the new OptimizerInfo infrastructure. Specifically, for the `torch.optim.rmsprop.RMSprop` optimizer, the following kwargs are never returned when `_get_optim_inputs_including_global_cliquey_kwargs` is called:

```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{'foreach': True, 'maximize': True, 'weight_decay': 0}
```

I ran into the same issue for `test_graph_scaling_fused_optimizers`: for the `torch.optim.adamw.AdamW` optimizer, the following kwarg was not returned when `optim_info.optim_inputs_func(device=device)` was called:

```
{'amsgrad': True}
```

Because of this, I resorted to using a dictionary to store the kwargs for each optimizer; I am aware that this is less than ideal. I was wondering whether I should instead use the OptimizerInfo infrastructure to get all the kwargs, even though it currently lacks some of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
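As background for the description above, here is a minimal sketch of the OptimizerInfo/`@optims` pattern this PR adopts. The class name `MyOptimExamples` and the test `test_one_step` are illustrative only (not code from this PR); the decorator, the `(self, device, dtype, optim_info)` signature, and the `optim_info` attributes mirror how the real tests in the diff below are wired up.

```
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests


class MyOptimExamples(TestCase):
    # The @optims decorator parametrizes this single test body over every
    # OptimizerInfo in the given subset of optim_db (one generated test per
    # optimizer and dtype), replacing an explicit loop over constructors.
    @onlyCUDA
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_one_step(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # optim_inputs_func yields OptimizerInput objects whose .kwargs are
        # representative constructor arguments for this optimizer.
        for optim_input in optim_info.optim_inputs_func(device=device):
            params = [torch.randn(2, 2, device=device, requires_grad=True)]
            opt = optim_cls(params, **optim_input.kwargs)
            params[0].grad = torch.randn_like(params[0])
            opt.step()


# Generates the device-specific test classes so the device/dtype arguments
# above are filled in at collection time.
instantiate_device_type_tests(MyOptimExamples, globals())

if __name__ == "__main__":
    run_tests()
```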
committed by PyTorch MergeBot
parent 5fb11cda4f
commit cf35a591b9
@@ -136,6 +136,8 @@ KERNEL_COUNT_OVERRIDES = {
    "test_sgd_momentum_foreach_cuda": 5,
    "test_sgd_weight_decay_maximize_cuda": 4,
    "test_sgd_weight_decay_maximize_cpu": 4,
    "test_sgd_weight_decay_cpu": 4,
    "test_sgd_weight_decay_cuda": 4,
    "test_sgd_momentum_weight_decay_foreach_cuda": 2,
    "test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
    "test_sgd_cuda": 4,
@@ -37,7 +37,11 @@ from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
    optims,
)
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    gcIfJetson,
@@ -3200,111 +3204,6 @@ exit(2)
        for p_control, p_graphed in zip(params_control, params_graphed):
            self.assertEqual(p_control, p_graphed)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_optims(self):
        # Needs generalization if we want to extend this test to non-Adam-like optimizers.
        cases = (
            [
                (
                    optimizer_ctor,
                    {
                        "lr": 0.1,
                        "betas": (0.8, 0.7),
                        "foreach": foreach,
                        "decoupled_weight_decay": decoupled_weight_decay,
                        "weight_decay": weight_decay,
                    },
                )
                for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product(
                    (
                        torch.optim.NAdam,
                        torch.optim.RAdam,
                    ),
                    (
                        False,
                        True,
                    ),
                    (
                        False,
                        True,
                    ),
                    (
                        0.0,
                        0.1,
                    ),
                )
            ]
            + [
                (
                    torch.optim.Rprop,
                    {"lr": 0.1, "foreach": foreach, "maximize": maximize},
                )
                for foreach, maximize in product(
                    (
                        False,
                        True,
                    ),
                    (
                        False,
                        True,
                    ),
                )
            ]
            + [
                (
                    optimizer_ctor,
                    {
                        "lr": 0.1,
                        "betas": (0.8, 0.7),
                        "foreach": foreach,
                        "amsgrad": amsgrad,
                    },
                )
                for optimizer_ctor, foreach, amsgrad in product(
                    (torch.optim.Adam, torch.optim.AdamW),
                    (False, True),
                    (False, True),
                )
            ]
            + [
                (
                    optimizer_ctor,
                    {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
                )
                for optimizer_ctor, amsgrad in product(
                    (torch.optim.Adam, torch.optim.AdamW), (False, True)
                )
            ]
            + [
                (
                    optimizer_ctor,
                    {
                        "lr": 0.1,
                        "foreach": foreach,
                        "maximize": maximize,
                        "weight_decay": weight_decay,
                    },
                )
                for optimizer_ctor, foreach, maximize, weight_decay in product(
                    (
                        torch.optim.Adamax,
                        torch.optim.ASGD,
                        torch.optim.Adadelta,
                        torch.optim.RMSprop,
                    ),
                    (False, True),
                    (False, True),
                    (0, 0.1),
                )
            ]
        )

        for optimizer_ctor, kwargs in cases:
            with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
                self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
@@ -3376,123 +3275,6 @@ exit(2)
        self.assertEqual(ref_p1, param1)
        self.assertEqual(ref_p2, param2)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    def test_graph_scaling_fused_optimizers(self):
        cases = [
            (
                optimizer_ctor,
                {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
            )
            for optimizer_ctor, amsgrad in product(
                (torch.optim.Adam, torch.optim.AdamW), (False, True)
            )
        ] + list(
            product(
                (torch.optim.SGD,),
                [
                    {
                        "lr": 0.1,
                        "momentum": 0.0,
                        "dampening": d,
                        "weight_decay": w,
                        "nesterov": n,
                        "fused": True,
                    }
                    for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
                ]
                + [
                    {
                        "lr": 0.1,
                        "momentum": 0.5,
                        "dampening": d,
                        "weight_decay": w,
                        "nesterov": n,
                        "fused": True,
                    }
                    for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
                ],
            )
        )

        steps_warmup = 3
        steps_train = 2

        for OptClass, kwargs in cases:
            has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW)
            for actually_do_graphs in (True, False) if has_capturable_arg else (True,):
                params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]
                with torch.no_grad():
                    grads_control = [[g.clone() for g in gs] for gs in grads]
                    grads_graphed = [[g.clone() for g in gs] for gs in grads]

                # Gradient Scaler
                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
                with torch.no_grad():
                    scaler_for_control._lazy_init_scale_growth_tracker(
                        torch.device("cuda")
                    )

                scaler_for_graphed = torch.cuda.amp.GradScaler()
                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
                with torch.no_grad():
                    scaler_for_graphed._lazy_init_scale_growth_tracker(
                        torch.device("cuda")
                    )

                # Control (capturable=False)
                if has_capturable_arg:
                    kwargs["capturable"] = False
                opt = OptClass(params_control, **kwargs)

                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads_control[i][j]
                    scaler_for_control.step(opt)
                    scaler_for_control.update()

                # capturable=True
                if has_capturable_arg:
                    kwargs["capturable"] = True
                opt = OptClass(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads_graphed[i][j]
                    scaler_for_graphed.step(opt)
                    scaler_for_graphed.update()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads_graphed[i + steps_warmup][j]
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)

    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
@@ -4698,10 +4480,175 @@ class TestBlockStateAbsorption(TestCase):
        self.assertEqual(rc, "False", "Triton was imported when importing torch!")


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
    # These tests will be instantiate with instantiate_device_type_tests
    # to apply the new OptimizerInfo structure.

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs"
    )
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
            device, dtype, optim_info, skip=("differentiable",)
        )

        steps_warmup = 3
        steps_train = 2

        for optim_input in all_optim_inputs:
            kwargs = optim_input.kwargs

            # lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
            # and torch.optim.adamw
            kwargs["lr"] = 0.1

            for actually_do_graphs in (True, False):
                params = [
                    torch.randn((i + 5, i + 5), device=device) for i in range(2)
                ] + [torch.randn((), device=device)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]

                # Control (capturable=False)
                kwargs["capturable"] = False

                opt = optim_cls(params_control, **kwargs)
                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads[i][j]
                    opt.step()

                # capturable=True
                kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads[i][j]
                    opt.step()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        opt.step()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads[i + steps_warmup][j]
                        opt.step()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
    )
    @optims(
        [optim for optim in optim_db if "fused" in optim.supported_impls],
        dtypes=[torch.float32],
    )
    def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        steps_warmup = 3
        steps_train = 2

        optim_inputs = optim_info.optim_inputs_func(device=device)

        for optim_input in optim_inputs:
            kwargs = optim_input.kwargs
            kwargs["fused"] = True

            for actually_do_graphs in (
                (True, False) if optim_info.has_capturable_arg else (True,)
            ):
                params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
                params_control = [p.clone().requires_grad_() for p in params]
                params_graphed = [p.clone().requires_grad_() for p in params]

                # `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
                grads = [
                    [torch.randn_like(p) for p in params]
                    for _ in range(steps_warmup + steps_train)
                ]
                with torch.no_grad():
                    grads_control = [[g.clone() for g in gs] for gs in grads]
                    grads_graphed = [[g.clone() for g in gs] for gs in grads]

                # Gradient Scaler
                scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
                with torch.no_grad():
                    scaler_for_control._lazy_init_scale_growth_tracker(device)

                scaler_for_graphed = torch.cuda.amp.GradScaler()
                scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
                with torch.no_grad():
                    scaler_for_graphed._lazy_init_scale_growth_tracker(device)

                # Control (capturable=False)
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = False
                opt = optim_cls(params_control, **kwargs)

                for i in range(steps_warmup + steps_train):
                    for j, p in enumerate(params_control):
                        p.grad = grads_control[i][j]
                    scaler_for_control.step(opt)
                    scaler_for_control.update()

                # capturable=True
                if optim_info.has_capturable_arg:
                    kwargs["capturable"] = True
                opt = optim_cls(params_graphed, **kwargs)

                for i in range(steps_warmup):
                    for j, p in enumerate(params_graphed):
                        p.grad = grads_graphed[i][j]
                    scaler_for_graphed.step(opt)
                    scaler_for_graphed.update()

                if actually_do_graphs:
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for i in range(steps_train):
                    if actually_do_graphs:
                        for j, p in enumerate(params_graphed):
                            p.grad.copy_(grads_graphed[i + steps_warmup][j])
                        g.replay()
                    else:
                        # Passing capturable=True to the constructor and running without graphs should still be
                        # numerically correct, even if it's not ideal for performance.
                        for j, p in enumerate(params_graphed):
                            p.grad = grads_graphed[i + steps_warmup][j]
                        scaler_for_graphed.step(opt)
                        scaler_for_graphed.update()

                for p_control, p_graphed in zip(params_control, params_graphed):
                    self.assertEqual(p_control, p_graphed)

    @onlyCUDA
    @unittest.skipIf(
        not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
@@ -123,6 +123,8 @@ class OptimizerInfo:
        supported_impls: Tuple[str] = ("foreach", "differentiable"),
        # the optim supports passing in sparse gradients as well as dense grads
        supports_sparse: bool = False,
        # the optimizer constructor supports passing in capturable as a kwarg
        has_capturable_arg: bool = False,
        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
        only_supports_sparse_grads: bool = False,
        # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,

@@ -147,6 +149,7 @@ class OptimizerInfo:
        self.scheduler_inputs = scheduler_inputs
        self.supported_impls = supported_impls
        self.supports_sparse = supports_sparse
        self.has_capturable_arg = has_capturable_arg
        self.metadata_for_sparse = metadata_for_sparse
        self.only_supports_sparse_grads = only_supports_sparse_grads
        self.supports_complex = supports_complex

@@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None):
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"

@@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None):
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            kwargs={"maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize, weight_decay",
        ),
    ] + (cuda_supported_configs if "cuda" in str(device) else [])


@@ -683,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None):
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
            kwargs={
                "weight_decay": 0.1,
            },
            desc="weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
            desc="weight_decay, momentum_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "weight_decay": 0.1,
                "momentum_decay": 6e-3,
                "decoupled_weight_decay": True,
            },
            desc="decoupled_weight_decay",

@@ -818,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None):
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
            },
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True},
            desc="centered",
        ),
        OptimizerInput(
            params=None,
            kwargs={
                "maximize": True,
                "weight_decay": 0.1,
            },
            desc="maximize, weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},

@@ -836,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None):
                "momentum": 0.1,
                "maximize": True,
            },
            desc="maximize",
            desc="maximize, centered, weight_decay, w/ momentum",
        ),
    ] + (cuda_supported_configs if "cuda" in str(device) else [])

@@ -907,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None):
        OptimizerInput(
            params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
        ),
        OptimizerInput(
            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
        ),
        OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "dampening": 0.5},

@@ -916,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None):
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "weight_decay": 0.1},
            desc="non-zero weight_decay",
            desc="weight_decay w/ momentum",
        ),
        OptimizerInput(
            params=None,
            kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
            desc="nesterov",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
    ]


@@ -1097,6 +1130,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_adadelta,
        optim_error_inputs_func=optim_error_inputs_func_adadelta,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),

@@ -1232,6 +1266,7 @@ optim_db: List[OptimizerInfo] = [
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
        supports_fused_on=("cpu", "cuda"),
        has_capturable_arg=True,
        decorators=(
            # Expected floating point error between fused and compiled forloop
            DecorateInfo(

@@ -1298,6 +1333,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_adamax,
        optim_error_inputs_func=optim_error_inputs_func_adamax,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115

@@ -1348,6 +1384,7 @@ optim_db: List[OptimizerInfo] = [
        optim_error_inputs_func=optim_error_inputs_func_adamw,
        supported_impls=("foreach", "differentiable", "fused"),
        supports_fused_on=("cpu", "cuda"),
        has_capturable_arg=True,
        decorators=(
            # Expected error between compiled forloop and fused optimizers
            DecorateInfo(

@@ -1414,6 +1451,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_asgd,
        optim_error_inputs_func=optim_error_inputs_func_asgd,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),

@@ -1506,6 +1544,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_nadam,
        optim_error_inputs_func=optim_error_inputs_func_nadam,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115

@@ -1561,6 +1600,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_radam,
        optim_error_inputs_func=optim_error_inputs_func_radam,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),

@@ -1606,6 +1646,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_rmsprop,
        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115

@@ -1655,6 +1696,7 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_rprop,
        optim_error_inputs_func=optim_error_inputs_func_rprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfMps,  # Rprop doesn't update for non-contiguous, see #118117