[BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)

Removes part of the SparseAdam test and the following three tests: `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, and `test_duplicate_params_in_one_param_group`.

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
...
----------------------------------------------------------------------
Ran 3 tests in 0.023s

OK
```

Increases coverage by running the duplicate-param tests on ALL the optims instead of just one optimizer each. Also fixes a SparseAdam bug where a raw Tensor passed as `params` was accidentally unbound (iterating it via `list()` calls `torch.unbind` under the hood) instead of being wrapped in a list. This bug was caught by migrating the weird warning checks to one easy warning context manager, which also checks that nothing else gets raised.
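
A minimal standalone illustration of that difference (not code from the PR, just the underlying Tensor behavior):
```
import torch

t = torch.zeros(3)

# list(t) iterates the Tensor, which unbinds it along dim 0 into three
# scalar tensors; each would then be treated as its own "param".
print(list(t))  # [tensor(0.), tensor(0.), tensor(0.)]

# Wrapping the Tensor in a list keeps it as a single parameter, which is
# what the fixed code path (params = [params]) does.
print([t])      # [tensor([0., 0., 0.])]
```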

The new test_errors does not run slower than before; overhead is still king:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..........................
----------------------------------------------------------------------
Ran 26 tests in 10.337s

OK
```

Compared to test_errors BEFORE my commit :p
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
.............sssssssssssss
----------------------------------------------------------------------
Ran 26 tests in 11.980s

OK (skipped=13)
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116315
Approved by: https://github.com/mikaylagawarecki
Author: Jane Xu
Date: 2023-12-26 18:36:12 +00:00
Committed by: PyTorch MergeBot
Parent: 8abeacda6f
Commit: 44b98c09ca
4 changed files with 286 additions and 303 deletions

@@ -906,14 +906,6 @@ class TestOptim(TestCase):
            sparse_only=True,
            maximize=True,
        )
-        import warnings
-        with warnings.catch_warnings(record=True) as ws:
-            SparseAdam(torch.zeros(3))
-            self.assertEqual(len(ws), 1)
-            for warning in ws:
-                self.assertEqual(len(warning.message.args), 1)
-                self.assertRegex(warning.message.args[0],
-                                 "Passing in a raw Tensor as ``params`` to SparseAdam ")

    # ROCm precision is too low to pass this test
    def test_adadelta(self):
@@ -1438,20 +1430,6 @@ class TestOptim(TestCase):
        self.assertEqual(type(res1), type(res2))

-    def test_duplicate_params_in_one_param_group(self):
-        param = Parameter(torch.randn(1))
-        with self.assertWarnsOnceRegex(UserWarning, '.*a parameter group with duplicate parameters.*'):
-            Adamax([param, param], lr=0.01)
-
-    def test_duplicate_params_across_param_groups(self):
-        param = Parameter(torch.randn(1))
-        self.assertRaisesRegex(
-            ValueError,
-            'some parameters appear in more than one parameter group',
-            lambda: Adadelta([{'params': param}, {'params': param}])
-        )
-
    def test_fused_optimizer_does_not_step_if_foundinf(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required.")
@@ -1621,14 +1599,6 @@ class TestOptim(TestCase):
            opt2.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])

-    def test_fused_optimizer_raises(self):
-        if not torch.cuda.is_available():
-            self.skipTest("Requires CUDA devices")
-        for optimizer_ctor in (Adam, AdamW):
-            with self.assertRaisesRegex(RuntimeError, "`fused` and `foreach` cannot be `True` together."):
-                optimizer_ctor([torch.empty((), device="cuda")], foreach=True, fused=True)
-            with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
-                optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)
-
    @staticmethod
    def _state_dict_pre_hook(optimizer: Optimizer) -> None:

@@ -26,7 +26,6 @@ class TestOptimRenewed(TestCase):
        self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))

-    @onlyCPU
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
@@ -36,12 +35,20 @@ class TestOptimRenewed(TestCase):
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim_cls(params, **kwargs)
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim.step()
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")
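
For a concrete, runnable illustration of the dispatch above (Warning subclasses checked with `assertWarnsRegex`, everything else with `assertRaisesRegex`), here is a minimal standalone sketch. The test class, the `_check` helper, and the specific optimizers and messages are illustrative assumptions based on current `torch.optim` behavior, not part of this PR:
```
import unittest

import torch
from torch.optim import Adadelta, Adamax


class WarningVsErrorDemo(unittest.TestCase):
    # Hypothetical helper mirroring the branching in the new test_errors:
    # Warning subclasses go through assertWarnsRegex, everything else
    # through assertRaisesRegex.
    def _check(self, ctor, error_type, error_regex):
        if issubclass(error_type, Warning):
            with self.assertWarnsRegex(error_type, error_regex):
                ctor()
        else:
            with self.assertRaisesRegex(error_type, error_regex):
                ctor()

    def test_duplicate_params_in_one_group_warns(self):
        param = torch.nn.Parameter(torch.randn(1))
        # Duplicate params within a single param group currently only warn.
        self._check(
            lambda: Adamax([param, param], lr=0.01),
            UserWarning,
            "a parameter group with duplicate parameters",
        )

    def test_invalid_rho_raises(self):
        param = torch.nn.Parameter(torch.randn(1))
        # An out-of-range hyperparameter raises at construction time.
        self._check(
            lambda: Adadelta([param], rho=1.1),
            ValueError,
            "Invalid rho value: 1.1",
        )


if __name__ == "__main__":
    unittest.main()
```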

@@ -260,6 +260,7 @@ class Optimizer:
                    "is deprecated. In the future, this will raise an error. "
                    "Please wrap your Tensor in an iterable instead."),
                    FutureWarning)
+            params = [params]
        else:
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +

@@ -190,6 +190,43 @@ class optims(_TestParametrizer):
raise ex
# Helper function for generating error inputs for all optimizers, used below.
def get_error_inputs_for_all_optims(device, dtype):
if str(device) == "cpu":
sample_param = Parameter(torch.randn(1, device=device, dtype=dtype))
return [
ErrorOptimizerInput(
OptimizerInput(
params=sample_param,
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
ErrorOptimizerInput(
OptimizerInput(
params=[sample_param, sample_param],
kwargs={},
desc="a param group cannot have duplicate parameters",
),
error_type=UserWarning,
error_regex=".*a parameter group with duplicate parameters.*",
),
ErrorOptimizerInput(
OptimizerInput(
params=[{"params": sample_param}, {"params": sample_param}],
kwargs={},
desc="duplicate parameters should not occur across param groups either",
),
error_type=ValueError,
error_regex="some parameters appear in more than one parameter group",
),
]
else:
return []
# ------------------------------------------------------------------------------------------
# NOTE: [optimizer kwarg categories]
# We categorize optimizer kwargs as 3 types:
@@ -237,26 +274,20 @@ def optim_inputs_func_adadelta(device=None):
def optim_error_inputs_func_adadelta(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, rho=1.1),
desc="rho should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, rho=1.1),
desc="rho should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid rho value: 1.1",
),
error_type=ValueError,
error_regex="Invalid rho value: 1.1",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_adagrad(device=None):
@@ -284,26 +315,20 @@ def optim_inputs_func_adagrad(device=None):
def optim_error_inputs_func_adagrad(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, lr_decay=-0.5),
desc="lr_decay must be bigger than 0",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, lr_decay=-0.5),
desc="lr_decay must be bigger than 0",
),
error_type=ValueError,
error_regex="Invalid lr_decay value: -0.5",
),
error_type=ValueError,
error_regex="Invalid lr_decay value: -0.5",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
@@ -341,44 +366,60 @@ def optim_inputs_func_adam(device=None):
def optim_error_inputs_func_adam(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=torch.tensor(0.001), foreach=True),
desc="lr as Tensor doesn't work with foreach & not capturable",
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=torch.tensor(0.001), foreach=True),
desc="lr as Tensor doesn't work with foreach & not capturable",
),
error_type=ValueError,
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
),
error_type=ValueError,
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
]
if str(device) == "cuda":
sample_tensor = torch.empty((), device=device, dtype=dtype)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=[sample_tensor],
kwargs={"foreach": True, "fused": True},
desc="`fused` and `foreach` cannot be `True` together",
),
error_type=RuntimeError,
error_regex="`fused` and `foreach` cannot be `True` together",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
ErrorOptimizerInput(
OptimizerInput(
params=[sample_tensor],
kwargs={"fused": True, "differentiable": True},
desc="`fused` does not support `differentiable`",
),
error_type=RuntimeError,
error_regex="`fused` does not support `differentiable`",
),
]
return error_inputs
def optim_inputs_func_adamax(device=None):
@@ -397,26 +438,20 @@ def optim_inputs_func_adamax(device=None):
def optim_error_inputs_func_adamax(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
desc="beta2 should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
desc="beta2 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 1: 1.0",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 1: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_adamw(device=None):
@@ -444,26 +479,20 @@ def optim_inputs_func_asgd(device=None):
def optim_error_inputs_func_asgd(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-0.5),
desc="weight_decay should > 0",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-0.5),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -0.5",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -0.5",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_lbfgs(device=None):
@@ -482,17 +511,7 @@ def optim_inputs_func_lbfgs(device=None):
def optim_error_inputs_func_lbfgs(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
return get_error_inputs_for_all_optims(device, dtype)
# Weird story bro, NAdam and RAdam do not have maximize.
@@ -526,35 +545,29 @@ def optim_inputs_func_nadam(device=None):
def optim_error_inputs_func_nadam(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum_decay=-0.2),
desc="momentum_decay should > 0",
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum_decay=-0.2),
desc="momentum_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid momentum_decay value: -0.2",
),
error_type=ValueError,
error_regex="Invalid momentum_decay value: -0.2",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
# Weird story bro, NAdam and RAdam do not have maximize.
@@ -575,35 +588,29 @@ def optim_inputs_func_radam(device=None):
def optim_error_inputs_func_radam(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_rmsprop(device=None):
@@ -637,26 +644,20 @@ def optim_inputs_func_rmsprop(device=None):
def optim_error_inputs_func_rmsprop(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-1.0),
desc="momentum should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-1.0),
desc="momentum should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid momentum value: -1.0",
),
error_type=ValueError,
error_regex="Invalid momentum value: -1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_rprop(device=None):
@@ -676,26 +677,20 @@ def optim_inputs_func_rprop(device=None):
def optim_error_inputs_func_rprop(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
desc="0 < eta1 < 1 < eta2",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
desc="0 < eta1 < 1 < eta2",
),
error_type=ValueError,
error_regex="Invalid eta values: 1.0, 0.5",
),
error_type=ValueError,
error_regex="Invalid eta values: 1.0, 0.5",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_sgd(device=None):
@@ -728,26 +723,20 @@ def optim_inputs_func_sgd(device=None):
def optim_error_inputs_func_sgd(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-0.5),
desc="momentum should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-0.5),
desc="momentum should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid momentum value: -0.5",
),
error_type=ValueError,
error_regex="Invalid momentum value: -0.5",
),
ErrorOptimizerInput(
OptimizerInput(
params=Parameter(torch.randn(1, device=device, dtype=dtype)),
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
]
]
return error_inputs
def optim_inputs_func_sparseadam(device=None):
@@ -761,45 +750,61 @@ def optim_inputs_func_sparseadam(device=None):
def optim_error_inputs_func_sparseadam(device, dtype):
return [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if str(device) == "cpu":
# SparseAdam raises a warning and not an error for the first entry. We
# update it here:
error_inputs[0].error_type = FutureWarning
error_inputs[
0
].error_regex = "Passing in a raw Tensor as ``params`` to SparseAdam"
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=[
torch.zeros(3, layout=torch.sparse_coo, device=device, dtype=dtype)
],
kwargs={},
desc="dense params required",
ErrorOptimizerInput(
OptimizerInput(
params=[
torch.zeros(
3, layout=torch.sparse_coo, device=device, dtype=dtype
)
],
kwargs={},
desc="dense params required",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
ErrorOptimizerInput(
OptimizerInput(
params=[
{
"params": [
torch.zeros(
3, layout=torch.sparse_coo, device=device, dtype=dtype
)
]
}
],
kwargs={},
desc="dense params required in param_groups",
ErrorOptimizerInput(
OptimizerInput(
params=[
{
"params": [
torch.zeros(
3,
layout=torch.sparse_coo,
device=device,
dtype=dtype,
)
]
}
],
kwargs={},
desc="dense params required in param_groups",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
]
]
return error_inputs
# Database of OptimizerInfo entries in alphabetical order.