[optim] Rectify capturable testing and fix bugs! (#118326)

This PR fixes several bugs, listed in order of priority:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, because `__setstate__` did not create the step tensors on the right device. This has been fixed (see the sketch after this list).
2. The most recently added capturable implementations were missing the eager-mode check that all tensors must be on CUDA. Those checks have now been added.
3. The most recent change to Adamax only adds capturable for foreach and is silently incorrect for the forloop/single-tensor path. I've added an error for that case and modified testing with many, many skips. Honestly, this PR has only further cemented my preference that we land the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding the CUDA-supported configs to the optimizer infos was incorrect (it compared `str(device) == "cuda"`, which fails when the device is e.g. `"cuda:0"`), so we hadn't been testing capturable at all! This has been rectified and was the trigger for this PR in the first place.
5. Similarly, the conditional in `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well (it checked `str(device)` against the fused/foreach supported-device lists instead of the device type). This has also been corrected.
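
To make (1) concrete, here is a minimal sketch of the pattern the Adam/AdamW/Adamax/ASGD/NAdam hunks below now follow in `__setstate__`: a non-tensor `step` loaded from an old state dict is materialized on the param's device when capturable or fused is set, and on CPU otherwise. The wrapper function name here is hypothetical; `_get_scalar_dtype` is the existing private helper the diff already uses.
```python
import torch
from torch.optim.optimizer import _get_scalar_dtype  # private helper used in the hunks below

def _restore_step_tensor(p_state, param, capturable, fused=False):
    # Hypothetical standalone version of the per-param logic in __setstate__.
    # Capturable/fused kernels read `step` on-device, so the restored tensor
    # must live on the same device as its param; otherwise it stays a CPU scalar tensor.
    if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
        step_val = float(p_state["step"])
        if capturable or fused:
            p_state["step"] = torch.tensor(
                step_val, dtype=_get_scalar_dtype(is_fused=fused), device=param.device
            )
        else:
            p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())
```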

The following is not a bug fix, but simplifies life by removing the need to handle Nones: `optim_input_funcs` must now take a `device`, which can be a string or a torch.device (see the sketch below).
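
For illustration, here is a hedged sketch of what a device-aware input func can look like (the func name and the plain-dict configs are placeholders, not from the PR; `_get_device_type` mirrors the helper this diff adds to the common optimizer infos):
```python
from typing import Union
import torch

def _get_device_type(device: Union[str, torch.device]) -> str:
    # Mirrors the helper added below: "cuda:0" or torch.device("cuda", 0) -> "cuda"
    if isinstance(device, torch.device):
        device = str(device.type)
    return device.split(":")[0]

def optim_inputs_func_example(device):
    # Hypothetical input func following the new convention: `device` is required
    # and may be a str or a torch.device. CUDA-only configs (e.g. capturable)
    # are appended only when the device type is actually "cuda".
    cuda_supported_configs = [{"capturable": True}]
    base_configs = [{}, {"lr": 0.01}, {"weight_decay": 0.1}]
    return base_configs + (
        cuda_supported_configs if _get_device_type(device) == "cuda" else []
    )
```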

Details for posterity:
4. Running the `test_foreach_matches_forloop` test and printing the generated configs shows that capturable now gets included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the `test_optimizer_can_be_printed` test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run now also shows the correct configs.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
Author: Jane Xu
Date: 2024-02-02 08:05:26 -08:00
Committed by: PyTorch MergeBot
Parent: 8b00e5aa12
Commit: b5ba80828f
12 changed files with 332 additions and 123 deletions

View File

@ -44,6 +44,10 @@ def get_optimizer_step(opt, closure=None):
def make_test(optim_cls, closure=None, **kwargs):
# Remove this conditional when #118230 is fixed
if optim_cls.__name__ == "Adamax":
kwargs["foreach"] = True
opt = optim_cls(model.parameters(), **kwargs)
def test_fn(self):

View File

@ -74,10 +74,10 @@ KERNEL_COUNTS = {
SGD: KernelCounts(multitensor=2, singletensor=8),
RAdam: KernelCounts(
multitensor=2, singletensor=None
), # Single tensor eager needs to be refactored to enable tracing
), # Single tensor eager needs to be refactored to enable tracing (#118230)
Adamax: KernelCounts(
multitensor=2, singletensor=None
), # Single tensor eager needs to be refactored to enable tracing
), # Single tensor eager needs to be refactored to enable tracing (#117836)
}
@ -87,12 +87,9 @@ def build_compiled_opt_kwarg_db():
if optim_info.optim_cls not in KERNEL_COUNTS:
continue
for optim_inputs in optim_info.optim_inputs_func():
for device in ["cpu", "cuda"]:
for optim_inputs in optim_info.optim_inputs_func(device):
for foreach in [True, False]:
if device == "cpu" and "capturable" in optim_inputs.kwargs:
continue
kwargs = dict(optim_inputs.kwargs)
name = (
f"test_{optim_info.optim_cls.__name__.lower()}"
@ -107,7 +104,21 @@ def build_compiled_opt_kwarg_db():
name += f"_{device}"
# Eager for-loop impl doesn't support capturable ASGD
if name == "test_asgd_capturable_cuda":
if name in [
"test_asgd_capturable_cuda",
"test_asgd_maximize_capturable_cuda",
"test_asgd_weight_decay_capturable_cuda",
"test_asgd_weight_decay_maximize_capturable_cuda",
]:
continue
# Adam(W) capturable cudagraphs manager is unexpectedly None, #119026
if name in [
"test_adam_amsgrad_capturable_cuda",
"test_adam_foreach_amsgrad_capturable_cuda",
"test_adamw_amsgrad_capturable_cuda",
"test_adamw_foreach_amsgrad_capturable_cuda",
]:
continue
kwargs["foreach"] = foreach

View File

@ -583,9 +583,9 @@ class TestLazyModules(TestCase):
@suppress_warnings
def test_optimizer_pass(self):
# Add Adamax and RAdam when #118230 and #117836 are complete
optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
torch.optim.AdamW, torch.optim.Adamax,
torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
torch.optim.RMSprop, torch.optim.LBFGS]
def run_step(module, optim):

View File

@ -19,10 +19,8 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
# Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
# Radam does not support capturable single tensor
def _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs):
if optim_info.only_supports_capturable_on_foreach and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
kwargs["capturable"] = False
@markDynamoStrictTest
@ -71,6 +69,9 @@ class TestOptimRenewed(TestCase):
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
_force_capturable_False_for_unsupported_single_tensor(optim_info, optim_input.kwargs)
if contiguous:
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
@ -79,8 +80,6 @@ class TestOptimRenewed(TestCase):
bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
input = torch.randn(5, device=device, dtype=dtype)
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
def closure():
@ -109,13 +108,14 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
optim_inputs = optim_info.optim_inputs_func(device="cuda")
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
if "foreach" in optim_info.supported_impls:
optim_input.kwargs["foreach"] = False # force forloop
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
@ -148,13 +148,19 @@ class TestOptimRenewed(TestCase):
def test_complex(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
# Also skip fused, since our fused kernels do not support complex
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable", "fused"))
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
# Last param is intentionally real to test that we can mix real and complex
complex_params = [
torch.randn(10, 5, dtype=dtype, requires_grad=True),
torch.randn(10, dtype=dtype, requires_grad=True),
torch.randn(10, 5, dtype=torch.float32, requires_grad=True),
torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
torch.randn(10, device=device, dtype=dtype, requires_grad=True),
torch.randn(10, 5, device=device, dtype=torch.float32, requires_grad=True),
]
real_params = [
(
@ -164,8 +170,7 @@ class TestOptimRenewed(TestCase):
)
for param in complex_params
]
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
real_steps = []
@ -234,14 +239,13 @@ class TestOptimRenewed(TestCase):
for optim_input in optim_inputs:
updated_params, state = [], []
kwargs = deepcopy(optim_input.kwargs)
if (kwargs.get("capturable", False) and str(device) == "cpu"):
if kwargs.get("capturable", False) and str(device) == "cpu":
# capturable is not supported on CPU
continue
for flag_value in (False, True):
kwargs[flag] = flag_value
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
input = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
@ -344,7 +348,10 @@ class TestOptimRenewed(TestCase):
for optim_input in optim_inputs:
updated_params, state = [], []
kwargs = deepcopy(optim_input.kwargs)
if kwargs.get("capturable", False) and str(device) == "cpu":
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
if kwargs.get("capturable", False) and str(device) == "cpu" :
# capturable is not supported on CPU
continue
for use_impl in (False, True):
@ -357,8 +364,6 @@ class TestOptimRenewed(TestCase):
p_clone.grad = p.grad.clone().detach()
params_clone.append(p_clone)
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
optimizer = optim_cls(params_clone, **kwargs)
for _ in range(kIterations):
optimizer.step()
@ -393,6 +398,7 @@ class TestOptimRenewed(TestCase):
# default dtype is higher prec float64
old_default_dtype = torch.get_default_dtype()
for default_dtype in [torch.float64, torch.float16]:
try:
torch.set_default_dtype(default_dtype)
self._test_derived_optimizers(
device,
@ -402,6 +408,7 @@ class TestOptimRenewed(TestCase):
reduced_precision=default_dtype == torch.float16,
assert_step_dtype=torch.float64 if default_dtype == torch.float64 else torch.float32,
)
finally:
torch.set_default_dtype(old_default_dtype)
@ -431,8 +438,7 @@ class TestOptimRenewed(TestCase):
for flag_value in (False, True):
kwargs["foreach"] = flag_value
# https://github.com/pytorch/pytorch/issues/118230
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
@ -539,6 +545,11 @@ class TestOptimRenewed(TestCase):
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
for optim_input in all_optim_inputs:
# See https://github.com/pytorch/pytorch/issues/117836 and #118230
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
weight_kwargs = optim_input.kwargs
bias_kwargs = deepcopy(optim_input.kwargs)
bias_kwargs["weight_decay"] = 0.0
@ -575,6 +586,11 @@ class TestOptimRenewed(TestCase):
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
for optim_input in all_optim_inputs:
# See https://github.com/pytorch/pytorch/issues/117836 and #118230
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
# optim_input.kwargs will be the param group kwargs, which should have >0 lr
if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
optim_input.kwargs["lr"] = 1e-3
@ -630,10 +646,12 @@ class TestOptimRenewed(TestCase):
return torch.tensor([1], device=device, dtype=dtype)
for optim_input in all_optim_inputs:
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
optimizer.step(closure)
self.assertEqual(old_params, params)
@optims(optim_db, dtypes=[torch.float32])
@ -648,7 +666,10 @@ class TestOptimRenewed(TestCase):
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
and not kwargs.get("foreach", False)):
continue
# params will decay even if grads are empty if weight_decay != 0,
# and capturable doesn't work for CPU tensors
@ -657,7 +678,7 @@ class TestOptimRenewed(TestCase):
# AdamW params will be updated regardless of grads due to lr, so make lr smaller
if optim_cls.__name__ == "AdamW":
kwargs["lr"] = torch.tensor(1e-4) if isinstance(kwargs.get("lr", 1e-4), torch.Tensor) else 1e-4
kwargs["lr"] = torch.tensor(1e-5) if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) else 1e-5
if kwargs.get("differentiable", False):
params = [param.clone()]
@ -684,6 +705,10 @@ class TestOptimRenewed(TestCase):
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
params = [Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)) for _ in range(2)]
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
optimizer.__repr__()
@ -706,6 +731,10 @@ class TestOptimRenewed(TestCase):
return loss
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
@ -749,6 +778,10 @@ class TestOptimRenewed(TestCase):
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Conv2d(4, 2, 1, stride=2),
@ -811,9 +844,8 @@ class TestOptimRenewed(TestCase):
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
# See https://github.com/pytorch/pytorch/issues/117836 for Adamax
# See https://github.com/pytorch/pytorch/issues/118230 for RAdam
if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
and not kwargs.get("foreach", False)):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
@ -843,6 +875,10 @@ class TestOptimRenewed(TestCase):
return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
for optim_input in cpu_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
for p in params:
p.grad = torch.randn_like(p)
@ -906,6 +942,10 @@ class TestOptimRenewed(TestCase):
return {k for k in obj.__dict__ if not k.startswith("_")}
for optim_input in all_optim_inputs:
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
and not optim_input.kwargs.get("foreach", False)):
continue
optimizer = optim_cls(params, **optim_input.kwargs)
# Make some state

View File

@ -1568,7 +1568,7 @@ class TorchPatcher:
}
excluded_single_tensor = {
radam, # https://github.com/pytorch/pytorch/issues/117807
radam, # https://github.com/pytorch/pytorch/issues/118230
}
for opt_mod in optimizer_modules:

View File

@ -74,7 +74,10 @@ class Adam(Optimizer):
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
if group['capturable'] or group['fused']
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
def _init_group(
self,

View File

@ -34,6 +34,9 @@ class Adamax(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if foreach is False and capturable:
raise ValueError("Capturable not supported with single tensor Adamax")
defaults = dict(
lr=lr,
betas=betas,
@ -53,13 +56,12 @@ class Adamax(Optimizer):
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
has_complex = False
@ -265,6 +267,8 @@ def _single_tensor_adamax(
capturable: bool,
has_complex: bool,
):
if capturable:
raise RuntimeError("capturable is not supported for single tensor Adamax (when foreach=False)")
for i, param in enumerate(params):
grad = grads[i]
@ -272,6 +276,7 @@ def _single_tensor_adamax(
exp_avg = exp_avgs[i]
exp_inf = exp_infs[i]
step_t = state_steps[i]
# update step
step_t += 1
@ -328,6 +333,11 @@ def _multi_tensor_adamax(
if len(params) == 0:
return
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if (not torch._utils.is_compiling() and capturable
and not all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps))):
raise RuntimeError("If capturable=True, params and state_steps must be CUDA tensors.")
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
if has_complex:

View File

@ -83,7 +83,10 @@ class AdamW(Optimizer):
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
if group['capturable'] or group['fused']
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
def _init_group(
self,

View File

@ -34,7 +34,7 @@ class ASGD(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if foreach is False and capturable:
if foreach is False and capturable and not is_compiling():
raise ValueError("Capturable not supported with single tensor ASGD")
defaults = dict(
@ -57,25 +57,20 @@ class ASGD(Optimizer):
group.setdefault("maximize", False)
group.setdefault("differentiable", False)
group.setdefault("capturable", False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["eta"]
)
if not eta_is_tensor:
for s in state_values:
s["eta"] = torch.tensor(s["eta"], dtype=_get_scalar_dtype())
mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["mu"]
)
if not mu_is_tensor:
for s in state_values:
s["mu"] = torch.tensor(float(s["mu"]), dtype=_get_scalar_dtype())
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0:
if not torch.is_tensor(p_state['step']):
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
if not torch.is_tensor(p_state["eta"]):
p_state["eta"] = (torch.tensor(p_state["eta"], dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"] else torch.tensor(p_state["eta"], dtype=_get_scalar_dtype()))
if not torch.is_tensor(p_state["mu"]):
p_state["mu"] = (torch.tensor(p_state["mu"], dtype=_get_scalar_dtype(), device=p.device)
if group["capturable"] else torch.tensor(p_state["mu"], dtype=_get_scalar_dtype()))
def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
has_complex = False
@ -206,8 +201,6 @@ def asgd(
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
if capturable and not is_compiling():
raise RuntimeError("Capturable not supported with single tensor ASGD")
func = _single_tensor_asgd
func(
@ -247,6 +240,9 @@ def _single_tensor_asgd(
capturable: bool,
has_complex: bool,
):
if capturable and not is_compiling():
raise RuntimeError("capturable is not supported for single tensor ASGD (when foreach=False)")
for i, param in enumerate(params):
grad = grads[i]
grad = grad if not maximize else -grad
@ -304,12 +300,17 @@ def _multi_tensor_asgd(
capturable: bool,
has_complex: bool,
):
if len(params) == 0:
return
assert not differentiable, "_foreach ops don't support autograd"
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert all(p.is_cuda and mu.is_cuda and eta.is_cuda and step.is_cuda
for p, mu, eta, step in zip(params, mus, etas, state_steps)), \
"If capturable=True, params, mu_products, and state_steps must be CUDA tensors."
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():

View File

@ -37,15 +37,18 @@ class NAdam(Optimizer):
group.setdefault('capturable', False)
group.setdefault('differentiable', False)
group.setdefault('decoupled_weight_decay', False)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']), dtype=_get_scalar_dtype())
mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
if not mu_product_is_tensor:
for s in state_values:
s['mu_product'] = torch.tensor(s['mu_product'], dtype=_get_scalar_dtype())
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0:
if not torch.is_tensor(p_state['step']):
step_val = float(p_state["step"])
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
if not torch.is_tensor(p_state['mu_product']):
mu_prod_val = p_state["mu_product"]
p_state["mu_product"] = (torch.tensor(mu_prod_val, dtype=_get_scalar_dtype(), device=p.device)
if group['capturable'] else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()))
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
has_complex = False

View File

@ -44,6 +44,10 @@ class RAdam(Optimizer):
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if foreach is False and capturable:
raise ValueError("Capturable not supported with single tensor RAdam")
defaults = dict(
lr=lr,
betas=betas,
@ -208,7 +212,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
decay as in AdamW to obtain RAdamW (default: False)
{_foreach_doc}
{_differentiable_doc}
{_capturable_doc}
{_capturable_doc} For RAdam, capturable is only supported when foreach=True.
.. _On the variance of the adaptive learning rate and beyond:
https://arxiv.org/abs/1908.03265
@ -297,7 +301,7 @@ def _single_tensor_radam(
has_complex: bool,
):
if capturable:
raise RuntimeError("capturable is not supported for single tensor radam")
raise RuntimeError("capturable is not supported for single tensor RAdam (when foreach=False)")
for i, param in enumerate(params):
grad = grads[i]

View File

@ -114,6 +114,8 @@ class OptimizerInfo:
supports_param_groups: bool = True,
# whether the optimizer supports parameters on multiple devices
supports_multiple_devices: bool = True,
# whether the optimizer ONLY supports capturable on foreach vs. both foreach and forloop
only_supports_capturable_on_foreach: bool = False,
skips=(), # Indicates which tests to skip
decorators=None, # Additional decorators to apply to generated tests
optim_error_inputs_func=None, # Function to generate optim inputs that error
@ -126,6 +128,7 @@ class OptimizerInfo:
self.step_requires_closure = step_requires_closure
self.supports_param_groups = supports_param_groups
self.supports_multiple_devices = supports_multiple_devices
self.only_supports_capturable_on_foreach = only_supports_capturable_on_foreach
self.decorators = (
*(decorators if decorators else []),
*(skips if skips else []),
@ -262,7 +265,7 @@ def get_error_inputs_for_all_optims(device, dtype):
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
def optim_inputs_func_adadelta(device=None):
def optim_inputs_func_adadelta(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@ -297,7 +300,7 @@ def optim_error_inputs_func_adadelta(device, dtype):
return error_inputs
def optim_inputs_func_adagrad(device=None):
def optim_inputs_func_adagrad(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
@ -341,7 +344,7 @@ def optim_error_inputs_func_adagrad(device, dtype):
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
def optim_inputs_func_adam(device=None):
def optim_inputs_func_adam(device):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -370,7 +373,7 @@ def optim_inputs_func_adam(device=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
),
] + (cuda_supported_configs if str(device) == "cuda" else [])
] + (cuda_supported_configs if "cuda" in str(device) else [])
def optim_error_inputs_func_adam(device, dtype):
@ -405,7 +408,7 @@ def optim_error_inputs_func_adam(device, dtype):
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
),
]
if str(device) == "cuda":
if "cuda" in str(device):
sample_tensor = torch.empty((), device=device, dtype=dtype)
error_inputs += [
ErrorOptimizerInput(
@ -430,7 +433,7 @@ def optim_error_inputs_func_adam(device, dtype):
return error_inputs
def optim_inputs_func_adamax(device=None):
def optim_inputs_func_adamax(device):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -461,11 +464,23 @@ def optim_inputs_func_adamax(device=None):
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
] + (cuda_supported_configs if str(device) == "cuda" else [])
] + (cuda_supported_configs if "cuda" in str(device) else [])
def optim_error_inputs_func_adamax(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if "cuda" in str(device):
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(foreach=False, capturable=True),
desc="single tensor capturable not supported",
),
error_type=ValueError,
error_regex="Capturable not supported with single tensor Adamax",
)
]
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
@ -481,15 +496,33 @@ def optim_error_inputs_func_adamax(device, dtype):
return error_inputs
def optim_inputs_func_adamw(device=None):
return optim_inputs_func_adam(device=device)
def optim_inputs_func_adamw(device):
return optim_inputs_func_adam(device)
def optim_error_inputs_func_adamw(device, dtype):
return optim_error_inputs_func_adam(device, dtype)
def optim_inputs_func_asgd(device=None):
def optim_inputs_func_asgd(device):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"maximize": True, "capturable": True},
desc="maximize, capturable",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "capturable": True},
desc="weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
desc="maximize, weight_decay, capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
@ -501,13 +534,25 @@ def optim_inputs_func_asgd(device=None):
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
desc="maximize, nonzero weight_decay",
),
]
] + (cuda_supported_configs if "cuda" in str(device) else [])
def optim_error_inputs_func_asgd(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if "cuda" in str(device):
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(foreach=False, capturable=True),
desc="single tensor capturable not supported",
),
error_type=ValueError,
error_regex="Capturable not supported with single tensor ASGD",
)
]
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
@ -523,7 +568,7 @@ def optim_error_inputs_func_asgd(device, dtype):
return error_inputs
def optim_inputs_func_lbfgs(device=None):
def optim_inputs_func_lbfgs(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@ -544,7 +589,7 @@ def optim_error_inputs_func_lbfgs(device, dtype):
# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_nadam(device=None):
def optim_inputs_func_nadam(device):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
@ -585,7 +630,7 @@ def optim_inputs_func_nadam(device=None):
},
desc="decoupled_weight_decay",
),
] + (cuda_supported_configs if str(device) == "cuda" else [])
] + (cuda_supported_configs if "cuda" in str(device) else [])
def optim_error_inputs_func_nadam(device, dtype):
@ -653,6 +698,18 @@ def optim_inputs_func_radam(device=None):
def optim_error_inputs_func_radam(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if "cuda" in str(device):
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(foreach=False, capturable=True),
desc="single tensor capturable not supported",
),
error_type=ValueError,
error_regex="Capturable not supported with single tensor RAdam",
),
]
if str(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
@ -677,7 +734,7 @@ def optim_error_inputs_func_radam(device, dtype):
return error_inputs
def optim_inputs_func_rmsprop(device=None):
def optim_inputs_func_rmsprop(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
@ -724,7 +781,7 @@ def optim_error_inputs_func_rmsprop(device, dtype):
return error_inputs
def optim_inputs_func_rprop(device=None):
def optim_inputs_func_rprop(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
@ -757,7 +814,7 @@ def optim_error_inputs_func_rprop(device, dtype):
return error_inputs
def optim_inputs_func_sgd(device=None):
def optim_inputs_func_sgd(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
@ -802,7 +859,7 @@ def optim_error_inputs_func_sgd(device, dtype):
return error_inputs
def optim_inputs_func_sparseadam(device=None):
def optim_inputs_func_sparseadam(device):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
@ -879,6 +936,14 @@ def optim_error_inputs_func_sparseadam(device, dtype):
return error_inputs
def _get_device_type(device: Union[str, torch.device]) -> str:
# Returns the device type as a string, e.g., "cpu" or "cuda"
if isinstance(device, torch.device):
device = str(device.type)
assert isinstance(device, str)
return device.split(":")[0]
def _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=()
) -> List[OptimizerInput]:
@ -897,14 +962,20 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
x in ["foreach", "fused", "differentiable"] for x in skip
), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
optim_inputs = optim_info.optim_inputs_func(device=device)
optim_inputs = optim_info.optim_inputs_func(device)
supported_impls = tuple(
x
for x in optim_info.supported_impls
if x not in skip
and (str(device) in _get_fused_kernels_supported_devices() or x != "fused")
and (str(device) in _get_foreach_kernels_supported_devices() or x != "foreach")
and (
_get_device_type(device) in _get_fused_kernels_supported_devices()
or x != "fused"
)
and (
_get_device_type(device) in _get_foreach_kernels_supported_devices()
or x != "foreach"
)
)
all_optim_inputs = []
@ -1131,6 +1202,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adamax,
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
only_supports_capturable_on_foreach=True, # Remove this line when #117836 is done!
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@ -1197,6 +1269,62 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
),
DecorateInfo(
skipIfTorchDynamo(
"cpu fails due to #115607; both devices fail cuz #117836"
),
"TestOptimRenewed",
"test_can_load_older_state_dict",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_step_is_noop_when_params_have_no_grad",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_load_nontensor_step",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_param_groups_weight_decay",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_param_groups_lr",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_state_dict_with_cuda_params",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
),
"TestOptimRenewed",
"test_mixed_device_dtype",
),
),
),
OptimizerInfo(
@ -1267,6 +1395,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_asgd,
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
only_supports_capturable_on_foreach=True, # Remove this line when #116052 is done!
skips=(
DecorateInfo(
skipIfTorchDynamo(
@ -1455,6 +1584,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_radam,
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
only_supports_capturable_on_foreach=True, # Remove this line when #118230 is done!
skips=(
DecorateInfo(
skipIfTorchDynamo(
@ -1540,6 +1670,13 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_complex",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
@ -1568,13 +1705,6 @@ optim_db: List[OptimizerInfo] = [
"TestOptimRenewed",
"test_param_groups_lr",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
DecorateInfo(
skipIfTorchDynamo(
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"