mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move sparse tests to TestOptimRenewed (#123146)
This is the last of the old TestOptim! With this change, everything will be migrated to use OptimizerInfo. Our sparse support is...well, sparse, and the tests try to best encapsulate which configs actually work. Note that support_sparse is actually just supports sparse grads...we don't test sparse params. 1. This PR fixes a bug in Adagrad multi_tensor with maximize by passing the correct value of maximize (vs False everytime) when sparse values are present. 2. This PR does improve coverage. There used to only be 2 configs each, and now we have the following configs for: Adagrad: ``` python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Adagrad /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( {'maximize': True, 'lr': 0.1} {'initial_accumulator_value': 0.1, 'lr': 0.1} <--- this and above are CPU .{'foreach': False, 'lr': 0.1} {'foreach': True, 'lr': 0.1} {'maximize': True, 'foreach': False, 'lr': 0.1} {'maximize': True, 'foreach': True, 'lr': 0.1} {'initial_accumulator_value': 0.1, 'foreach': False, 'lr': 0.1} {'initial_accumulator_value': 0.1, 'foreach': True, 'lr': 0.1} . ---------------------------------------------------------------------- Ran 2 tests in 227.744s OK ``` SGD ``` (pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_SGD /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( {'dampening': 0.5, 'lr': 0.0048} .{'foreach': False, 'lr': 0.0048} {'foreach': True, 'lr': 0.0048} {'dampening': 0.5, 'foreach': False, 'lr': 0.0048} {'dampening': 0.5, 'foreach': True, 'lr': 0.0048} . ---------------------------------------------------------------------- Ran 2 tests in 112.801s OK ``` SparseAdam ``` (pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Sparse /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( {'maximize': True, 'lr': 0.04} .{'maximize': True, 'lr': 0.04} . ---------------------------------------------------------------------- Ran 2 tests in 35.113s OK ``` Fixes #103322. A side quest in this migration was to re-enable and track dynamo issues as they trigger on the optim tests, which will be complete from this PR. New tests may add more things to track in dynamo, but there is now an established system for doing so, and dynamo is either enabled or a bug is tracked for every migrated test in TestOptimRenewed. Next steps: Remove the hyperparameter constraints in common_optimizer.py defined by metadata_for_sparse (other than LR, which seems handpicked for the tests to actually pass). Doing this requires adding more sparse functionality. Add more tests! Maybe add more optimizers! Pull Request resolved: https://github.com/pytorch/pytorch/pull/123146 Approved by: https://github.com/albanD ghstack dependencies: #123134, #123139
This commit is contained in:
committed by
PyTorch MergeBot
parent
f2838c99a0
commit
d7fe0603a1
@ -1,13 +1,9 @@
|
||||
# Owner(s): ["module: optimizer"]
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import (
|
||||
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
|
||||
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD
|
||||
)
|
||||
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
load_tests,
|
||||
@ -22,157 +18,6 @@ from torch.testing._internal.common_utils import (
|
||||
load_tests = load_tests
|
||||
|
||||
|
||||
def rosenbrock(tensor):
|
||||
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
x, y = tensor
|
||||
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
|
||||
|
||||
|
||||
def drosenbrock(tensor):
|
||||
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
x, y = tensor
|
||||
return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
|
||||
|
||||
@skipIfTorchDynamo("This is a TEMPORARY stopgap, see https://github.com/pytorch/pytorch/issues/103322")
|
||||
class TestOptim(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
def _test_rosenbrock_sparse(
|
||||
self,
|
||||
constructor,
|
||||
scheduler_constructors=None,
|
||||
sparse_only=False,
|
||||
maximize=False,
|
||||
multi_tensor=False
|
||||
):
|
||||
if scheduler_constructors is None:
|
||||
scheduler_constructors = []
|
||||
# For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
|
||||
if multi_tensor:
|
||||
params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=torch.float64)]
|
||||
else:
|
||||
params_t = [torch.tensor([1.5, 1.5])]
|
||||
|
||||
params = [Parameter(param_t) for param_t in params_t]
|
||||
optimizer = constructor(params)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
if not sparse_only:
|
||||
params_c = [Parameter(param_t.clone()) for param_t in params_t]
|
||||
optimizer_c = constructor(params_c)
|
||||
|
||||
solution = torch.tensor([1, 1])
|
||||
with torch.no_grad():
|
||||
initial_dist = sum([param.dist(solution) for param in params])
|
||||
|
||||
def get_grad(param, sparse_grad):
|
||||
grad = drosenbrock(param)
|
||||
# NB: We torture test the optimizer by returning an
|
||||
# uncoalesced sparse tensor
|
||||
|
||||
# Depending on w, provide only the x or y gradient
|
||||
if sparse_grad:
|
||||
if w:
|
||||
i = torch.LongTensor([[0, 0]])
|
||||
x = grad[0]
|
||||
v = torch.tensor([x / 4.0, x - x / 4.0])
|
||||
else:
|
||||
i = torch.LongTensor([[1, 1]])
|
||||
y = grad[1]
|
||||
v = torch.tensor([y - y / 4.0, y / 4.0])
|
||||
grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
|
||||
else:
|
||||
if w:
|
||||
grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
|
||||
else:
|
||||
grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
|
||||
return grad_out
|
||||
|
||||
def eval(params, sparse_grad, w):
|
||||
optimizer.zero_grad()
|
||||
if multi_tensor:
|
||||
loss = sum(rosenbrock(param) for param in params)
|
||||
else:
|
||||
loss = rosenbrock(params[0])
|
||||
loss.backward()
|
||||
|
||||
grads_out = [get_grad(param, sparse_grad) for param in params]
|
||||
with torch.no_grad():
|
||||
params[0].grad = grads_out[0]
|
||||
if multi_tensor:
|
||||
params[1].grad = grads_out[1].to(dtype=torch.float64)
|
||||
return loss
|
||||
|
||||
for i in range(2000):
|
||||
# Do cyclic coordinate descent
|
||||
w = i % 2
|
||||
optimizer.step(functools.partial(eval, params, True, w))
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, ReduceLROnPlateau):
|
||||
scheduler.step(rosenbrock(params[0]))
|
||||
else:
|
||||
scheduler.step()
|
||||
if not sparse_only:
|
||||
optimizer_c.step(functools.partial(eval, params_c, False, w))
|
||||
# Tolerance is increased due to floating point error from different
|
||||
# code path for dense case: x v.s. x - x / 4.0 + x / 4.0
|
||||
self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
|
||||
|
||||
if not maximize:
|
||||
self.assertLessEqual(
|
||||
sum([param.dist(solution) for param in params]),
|
||||
initial_dist
|
||||
)
|
||||
else:
|
||||
self.assertGreaterEqual(
|
||||
sum([rosenbrock(param) for param in params]),
|
||||
sum([rosenbrock(param_t) for param_t in params_t]),
|
||||
)
|
||||
|
||||
|
||||
def test_sgd_sparse(self):
|
||||
for foreach in (False, True):
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: SGD(params, lr=4.8e-3, foreach=foreach),
|
||||
multi_tensor=foreach,
|
||||
)
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: SGD(params, lr=0.0048, foreach=foreach),
|
||||
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
|
||||
multi_tensor=foreach,
|
||||
)
|
||||
|
||||
|
||||
def test_sparse_adam(self):
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: SparseAdam(params, lr=4e-2), [], True
|
||||
)
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: SparseAdam(params, lr=4e-2, maximize=True),
|
||||
scheduler_constructors=[],
|
||||
sparse_only=True,
|
||||
maximize=True,
|
||||
)
|
||||
|
||||
|
||||
def test_adagrad_sparse(self):
|
||||
for foreach in (False, True):
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: Adagrad(params, lr=1e-1, foreach=foreach),
|
||||
multi_tensor=foreach,
|
||||
)
|
||||
self._test_rosenbrock_sparse(
|
||||
lambda params: Adagrad(params, lr=0.1, foreach=foreach),
|
||||
scheduler_constructors=[
|
||||
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
|
||||
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
|
||||
],
|
||||
multi_tensor=foreach,
|
||||
)
|
||||
|
||||
|
||||
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
|
||||
# Ignored is the list of values in `opt_differentiable_state`, we do this
|
||||
# for `gradcheck` to correctly track the state tensors as function inputs
|
||||
|
@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
import torch
|
||||
from torch.optim import Optimizer, SGD
|
||||
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
|
||||
from optim.test_optim import TestOptim, TestDifferentiableOptimizer # noqa: F401
|
||||
from optim.test_optim import TestDifferentiableOptimizer # noqa: F401
|
||||
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
|
||||
from optim.test_swa_utils import TestSWAUtils # noqa: F401
|
||||
from torch.nn import Parameter
|
||||
@ -34,6 +34,12 @@ def rosenbrock(tensor):
|
||||
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
|
||||
|
||||
|
||||
def drosenbrock(tensor):
|
||||
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
|
||||
x, y = tensor
|
||||
return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
|
||||
|
||||
|
||||
@markDynamoStrictTest
|
||||
class TestOptimRenewed(TestCase):
|
||||
|
||||
@ -249,6 +255,129 @@ class TestOptimRenewed(TestCase):
|
||||
self.assertEqual(bias, bias_c)
|
||||
|
||||
|
||||
@parametrize("with_lrsched", [True, False])
|
||||
@optims([o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads], dtypes=[torch.float64])
|
||||
def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
|
||||
optim_cls = optim_info.optim_cls
|
||||
|
||||
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
|
||||
# Fused impls do not support sparse gradients
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
|
||||
device, dtype, optim_info, skip=("differentiable", "fused"))
|
||||
kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse
|
||||
|
||||
if with_lrsched and len(schedulers_constructors) == 0:
|
||||
return
|
||||
|
||||
supported_inputs = []
|
||||
if len(kwarg_updates) != 0:
|
||||
seen = set()
|
||||
for i in all_optim_inputs:
|
||||
for k in kwarg_updates:
|
||||
if k in i.kwargs:
|
||||
del i.kwargs[k]
|
||||
hashable_kwargs = tuple(sorted(i.kwargs.items()))
|
||||
if len(i.kwargs) > 0 and hashable_kwargs not in seen:
|
||||
supported_inputs.append(i)
|
||||
seen.add(hashable_kwargs)
|
||||
if "lr" in kwarg_updates:
|
||||
i.kwargs["lr"] = kwarg_updates["lr"]
|
||||
else:
|
||||
supported_inputs = all_optim_inputs
|
||||
|
||||
for optim_input in supported_inputs:
|
||||
kwargs = optim_input.kwargs
|
||||
multi_tensor = kwargs.get("foreach", False)
|
||||
|
||||
# For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
|
||||
if multi_tensor:
|
||||
params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=dtype)]
|
||||
else:
|
||||
params_t = [torch.tensor([1.5, 1.5])]
|
||||
|
||||
params = [Parameter(param_t) for param_t in params_t]
|
||||
optimizer = optim_cls(params, **kwargs)
|
||||
schedulers = [s(optimizer) for s in (schedulers_constructors if with_lrsched else [])]
|
||||
|
||||
if not optim_info.only_supports_sparse_grads:
|
||||
params_c = [Parameter(param_t.clone()) for param_t in params_t]
|
||||
optimizer_c = optim_cls(params_c, **kwargs)
|
||||
schedulers_c = [s(optimizer_c) for s in (schedulers_constructors if with_lrsched else [])]
|
||||
|
||||
solution = torch.tensor([1, 1])
|
||||
with torch.no_grad():
|
||||
initial_dist = sum([param.dist(solution) for param in params])
|
||||
|
||||
def get_grad(param, sparse_grad, w):
|
||||
grad = drosenbrock(param)
|
||||
# NB: We torture test the optimizer by returning an
|
||||
# uncoalesced sparse tensor
|
||||
|
||||
# Depending on w, provide only the x or y gradient
|
||||
if sparse_grad:
|
||||
if w:
|
||||
i = torch.tensor([[0, 0]], dtype=torch.int64)
|
||||
x = grad[0]
|
||||
v = torch.tensor([x / 4.0, x - x / 4.0])
|
||||
else:
|
||||
i = torch.tensor([[1, 1]], dtype=torch.int64)
|
||||
y = grad[1]
|
||||
v = torch.tensor([y - y / 4.0, y / 4.0])
|
||||
grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
|
||||
else:
|
||||
if w:
|
||||
grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
|
||||
else:
|
||||
grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
|
||||
return grad_out
|
||||
|
||||
def eval(params, sparse_grad, w):
|
||||
optimizer.zero_grad()
|
||||
if multi_tensor:
|
||||
loss = sum(rosenbrock(param) for param in params)
|
||||
else:
|
||||
loss = rosenbrock(params[0])
|
||||
loss.backward()
|
||||
|
||||
grads_out = [get_grad(param, sparse_grad, w) for param in params]
|
||||
with torch.no_grad():
|
||||
params[0].grad = grads_out[0]
|
||||
if multi_tensor:
|
||||
params[1].grad = grads_out[1].to(dtype=dtype)
|
||||
return loss
|
||||
|
||||
for i in range(1800):
|
||||
# Do cyclic coordinate descent
|
||||
w = i % 2
|
||||
optimizer.step(functools.partial(eval, params, True, w))
|
||||
for scheduler in schedulers:
|
||||
if isinstance(scheduler, ReduceLROnPlateau):
|
||||
scheduler.step(rosenbrock(params[0]))
|
||||
else:
|
||||
scheduler.step()
|
||||
if not optim_info.only_supports_sparse_grads:
|
||||
optimizer_c.step(functools.partial(eval, params_c, False, w))
|
||||
for scheduler in schedulers_c:
|
||||
if isinstance(scheduler, ReduceLROnPlateau):
|
||||
scheduler.step(rosenbrock(params_c[0]))
|
||||
else:
|
||||
scheduler.step()
|
||||
# Tolerance is increased due to floating point error from different
|
||||
# code path for dense case: x v.s. x - x / 4.0 + x / 4.0
|
||||
self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
|
||||
|
||||
if not kwargs.get("maximize", False):
|
||||
self.assertLessEqual(
|
||||
sum([param.dist(solution) for param in params]),
|
||||
initial_dist
|
||||
)
|
||||
else:
|
||||
self.assertGreaterEqual(
|
||||
sum([rosenbrock(param) for param in params]),
|
||||
sum([rosenbrock(param_t) for param_t in params_t]),
|
||||
)
|
||||
|
||||
|
||||
@skipMPS
|
||||
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
|
||||
def test_complex(self, device, dtype, optim_info):
|
||||
|
@ -338,7 +338,7 @@ def _multi_tensor_adagrad(
|
||||
lr_decay=lr_decay,
|
||||
eps=eps,
|
||||
has_sparse_grad=True,
|
||||
maximize=False,
|
||||
maximize=maximize,
|
||||
differentiable=differentiable,
|
||||
has_complex=has_complex,
|
||||
)
|
||||
|
@ -124,10 +124,14 @@ class OptimizerInfo:
|
||||
# A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
|
||||
# supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
|
||||
supported_impls: Tuple[str] = ("foreach", "differentiable"),
|
||||
# the devices on which the optim supports sparse tensors for params and grads, see SGD
|
||||
supports_sparse_on: Tuple[str] = (),
|
||||
# the optim supports passing in sparse gradients as well as dense grads
|
||||
supports_sparse: bool = False,
|
||||
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
|
||||
only_supports_sparse_grads: bool = False,
|
||||
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
|
||||
# with especially tuned hyperparameters. These only apply if the optimizer supports
|
||||
# sparse parameters or grads.
|
||||
metadata_for_sparse=({}, []),
|
||||
# the optim supports complex parameters
|
||||
supports_complex: bool = True,
|
||||
# whether the optimizer.step() function requires a closure to be passed
|
||||
@ -144,7 +148,8 @@ class OptimizerInfo:
|
||||
self.optim_inputs_func = optim_inputs_func
|
||||
self.scheduler_inputs = scheduler_inputs
|
||||
self.supported_impls = supported_impls
|
||||
self.supports_sparse_on = supports_sparse_on
|
||||
self.supports_sparse = supports_sparse
|
||||
self.metadata_for_sparse = metadata_for_sparse
|
||||
self.only_supports_sparse_grads = only_supports_sparse_grads
|
||||
self.supports_complex = supports_complex
|
||||
self.step_requires_closure = step_requires_closure
|
||||
@ -1135,7 +1140,14 @@ optim_db: List[OptimizerInfo] = [
|
||||
optim_inputs_func=optim_inputs_func_adagrad,
|
||||
optim_error_inputs_func=optim_error_inputs_func_adagrad,
|
||||
supported_impls=("foreach", "differentiable"),
|
||||
supports_sparse_on=("cpu"),
|
||||
supports_sparse=True,
|
||||
metadata_for_sparse=(
|
||||
{"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
|
||||
[
|
||||
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
|
||||
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
|
||||
],
|
||||
),
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
|
||||
@ -1184,6 +1196,13 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_set_default_dtype_works_with_foreach",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Fails assertion of params close to params_c at all, see #123147"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_rosenbrock_sparse",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
|
||||
@ -2077,7 +2096,17 @@ optim_db: List[OptimizerInfo] = [
|
||||
),
|
||||
optim_error_inputs_func=optim_error_inputs_func_sgd,
|
||||
supported_impls=("foreach", "differentiable", "fused"),
|
||||
supports_sparse_on=("cpu", "cuda"),
|
||||
supports_sparse=True,
|
||||
metadata_for_sparse=(
|
||||
{
|
||||
"lr": 4.8e-3,
|
||||
"maximize": False,
|
||||
"momentum": 0,
|
||||
"nesterov": False,
|
||||
"weight_decay": 0,
|
||||
},
|
||||
[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
|
||||
),
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
@ -2118,6 +2147,13 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_set_default_dtype_works_with_foreach",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Fails assertion of params close to params_c at all, see #123147"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_rosenbrock_sparse",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
|
||||
@ -2203,6 +2239,7 @@ optim_db: List[OptimizerInfo] = [
|
||||
optim_error_inputs_func=optim_error_inputs_func_sparseadam,
|
||||
supported_impls=(),
|
||||
only_supports_sparse_grads=True,
|
||||
metadata_for_sparse=({"lr": 4e-2}, []),
|
||||
supports_complex=False, # Missing complex support, see #118153
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
|
Reference in New Issue
Block a user