mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority: 1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed. 2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks 3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly my preference after this PR has only been further cemented that we should just do the single tensor and multi tensor capturable implementations together in the future. @mlazos 4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place. 5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected. The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device. Details for posterity: 4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct. ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={}, desc=default params=None, kwargs={'lr': 0.01}, desc=non-default lr params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad params=None, kwargs={'capturable': True}, desc=capturable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad . ---------------------------------------------------------------------- Ran 1 test in 19.229s OK ``` 5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct. ``` /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" params=None, kwargs={'differentiable': False}, desc=default params=None, kwargs={'differentiable': True}, desc=default & differentiable params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable .params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused . ---------------------------------------------------------------------- Ran 2 tests in 11.112s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
8b00e5aa12
commit
b5ba80828f
@ -44,6 +44,10 @@ def get_optimizer_step(opt, closure=None):
|
||||
|
||||
|
||||
def make_test(optim_cls, closure=None, **kwargs):
|
||||
# Remove this conditional when #118230 is fixed
|
||||
if optim_cls.__name__ == "Adamax":
|
||||
kwargs["foreach"] = True
|
||||
|
||||
opt = optim_cls(model.parameters(), **kwargs)
|
||||
|
||||
def test_fn(self):
|
||||
|
@ -74,10 +74,10 @@ KERNEL_COUNTS = {
|
||||
SGD: KernelCounts(multitensor=2, singletensor=8),
|
||||
RAdam: KernelCounts(
|
||||
multitensor=2, singletensor=None
|
||||
), # Single tensor eager needs to be refactored to enable tracing
|
||||
), # Single tensor eager needs to be refactored to enable tracing (#118230)
|
||||
Adamax: KernelCounts(
|
||||
multitensor=2, singletensor=None
|
||||
), # Single tensor eager needs to be refactored to enable tracing
|
||||
), # Single tensor eager needs to be refactored to enable tracing (#117836)
|
||||
}
|
||||
|
||||
|
||||
@ -87,12 +87,9 @@ def build_compiled_opt_kwarg_db():
|
||||
if optim_info.optim_cls not in KERNEL_COUNTS:
|
||||
continue
|
||||
|
||||
for optim_inputs in optim_info.optim_inputs_func():
|
||||
for device in ["cpu", "cuda"]:
|
||||
for optim_inputs in optim_info.optim_inputs_func(device):
|
||||
for foreach in [True, False]:
|
||||
if device == "cpu" and "capturable" in optim_inputs.kwargs:
|
||||
continue
|
||||
|
||||
kwargs = dict(optim_inputs.kwargs)
|
||||
name = (
|
||||
f"test_{optim_info.optim_cls.__name__.lower()}"
|
||||
@ -107,7 +104,21 @@ def build_compiled_opt_kwarg_db():
|
||||
name += f"_{device}"
|
||||
|
||||
# Eager for-loop impl doesn't support capturable ASGD
|
||||
if name == "test_asgd_capturable_cuda":
|
||||
if name in [
|
||||
"test_asgd_capturable_cuda",
|
||||
"test_asgd_maximize_capturable_cuda",
|
||||
"test_asgd_weight_decay_capturable_cuda",
|
||||
"test_asgd_weight_decay_maximize_capturable_cuda",
|
||||
]:
|
||||
continue
|
||||
|
||||
# Adam(W) capturable cudagraphs manager is unexpectedly None, #119026
|
||||
if name in [
|
||||
"test_adam_amsgrad_capturable_cuda",
|
||||
"test_adam_foreach_amsgrad_capturable_cuda",
|
||||
"test_adamw_amsgrad_capturable_cuda",
|
||||
"test_adamw_foreach_amsgrad_capturable_cuda",
|
||||
]:
|
||||
continue
|
||||
|
||||
kwargs["foreach"] = foreach
|
||||
|
@ -583,9 +583,9 @@ class TestLazyModules(TestCase):
|
||||
|
||||
@suppress_warnings
|
||||
def test_optimizer_pass(self):
|
||||
# Add Adamax and RAdam when #118230 and #117836 are complete
|
||||
optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
|
||||
torch.optim.AdamW, torch.optim.Adamax,
|
||||
torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
|
||||
torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
|
||||
torch.optim.RMSprop, torch.optim.LBFGS]
|
||||
|
||||
def run_step(module, optim):
|
||||
|
@ -19,10 +19,8 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
|
||||
FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
|
||||
|
||||
|
||||
def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
|
||||
# Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
|
||||
if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
|
||||
# Radam does not support capturable single tensor
|
||||
def _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs):
|
||||
if optim_info.only_supports_capturable_on_foreach and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
|
||||
kwargs["capturable"] = False
|
||||
|
||||
@markDynamoStrictTest
|
||||
@ -71,6 +69,9 @@ class TestOptimRenewed(TestCase):
|
||||
for optim_input in optim_inputs:
|
||||
if "foreach" in optim_info.supported_impls:
|
||||
optim_input.kwargs["foreach"] = False # force forloop
|
||||
|
||||
_force_capturable_False_for_unsupported_single_tensor(optim_info, optim_input.kwargs)
|
||||
|
||||
if contiguous:
|
||||
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
|
||||
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
|
||||
@ -79,8 +80,6 @@ class TestOptimRenewed(TestCase):
|
||||
bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
|
||||
input = torch.randn(5, device=device, dtype=dtype)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
|
||||
|
||||
def closure():
|
||||
@ -109,13 +108,14 @@ class TestOptimRenewed(TestCase):
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
|
||||
optim_cls = optim_info.optim_cls
|
||||
optim_inputs = optim_info.optim_inputs_func(device="cuda")
|
||||
optim_inputs = optim_info.optim_inputs_func(device=device)
|
||||
for optim_input in optim_inputs:
|
||||
if "foreach" in optim_info.supported_impls:
|
||||
optim_input.kwargs["foreach"] = False # force forloop
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
|
||||
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
|
||||
@ -148,13 +148,19 @@ class TestOptimRenewed(TestCase):
|
||||
def test_complex(self, device, dtype, optim_info):
|
||||
optim_cls = optim_info.optim_cls
|
||||
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
|
||||
# Also skip fused, since our fused kernels do not support complex
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
|
||||
device, dtype, optim_info, skip=("differentiable", "fused"))
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
# Last param is intentionally real to test that we can mix real and complex
|
||||
complex_params = [
|
||||
torch.randn(10, 5, dtype=dtype, requires_grad=True),
|
||||
torch.randn(10, dtype=dtype, requires_grad=True),
|
||||
torch.randn(10, 5, dtype=torch.float32, requires_grad=True),
|
||||
torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
|
||||
torch.randn(10, device=device, dtype=dtype, requires_grad=True),
|
||||
torch.randn(10, 5, device=device, dtype=torch.float32, requires_grad=True),
|
||||
]
|
||||
real_params = [
|
||||
(
|
||||
@ -164,8 +170,7 @@ class TestOptimRenewed(TestCase):
|
||||
)
|
||||
for param in complex_params
|
||||
]
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
|
||||
complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
|
||||
real_optimizer = optim_cls(real_params, **optim_input.kwargs)
|
||||
real_steps = []
|
||||
@ -234,14 +239,13 @@ class TestOptimRenewed(TestCase):
|
||||
for optim_input in optim_inputs:
|
||||
updated_params, state = [], []
|
||||
kwargs = deepcopy(optim_input.kwargs)
|
||||
if (kwargs.get("capturable", False) and str(device) == "cpu"):
|
||||
if kwargs.get("capturable", False) and str(device) == "cpu":
|
||||
# capturable is not supported on CPU
|
||||
continue
|
||||
for flag_value in (False, True):
|
||||
kwargs[flag] = flag_value
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
|
||||
|
||||
input = torch.tensor(
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
|
||||
@ -344,7 +348,10 @@ class TestOptimRenewed(TestCase):
|
||||
for optim_input in optim_inputs:
|
||||
updated_params, state = [], []
|
||||
kwargs = deepcopy(optim_input.kwargs)
|
||||
if kwargs.get("capturable", False) and str(device) == "cpu":
|
||||
|
||||
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
|
||||
|
||||
if kwargs.get("capturable", False) and str(device) == "cpu" :
|
||||
# capturable is not supported on CPU
|
||||
continue
|
||||
for use_impl in (False, True):
|
||||
@ -357,8 +364,6 @@ class TestOptimRenewed(TestCase):
|
||||
p_clone.grad = p.grad.clone().detach()
|
||||
params_clone.append(p_clone)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
optimizer = optim_cls(params_clone, **kwargs)
|
||||
for _ in range(kIterations):
|
||||
optimizer.step()
|
||||
@ -393,6 +398,7 @@ class TestOptimRenewed(TestCase):
|
||||
# default dtype is higher prec float64
|
||||
old_default_dtype = torch.get_default_dtype()
|
||||
for default_dtype in [torch.float64, torch.float16]:
|
||||
try:
|
||||
torch.set_default_dtype(default_dtype)
|
||||
self._test_derived_optimizers(
|
||||
device,
|
||||
@ -402,6 +408,7 @@ class TestOptimRenewed(TestCase):
|
||||
reduced_precision=default_dtype == torch.float16,
|
||||
assert_step_dtype=torch.float64 if default_dtype == torch.float64 else torch.float32,
|
||||
)
|
||||
finally:
|
||||
torch.set_default_dtype(old_default_dtype)
|
||||
|
||||
|
||||
@ -431,8 +438,7 @@ class TestOptimRenewed(TestCase):
|
||||
for flag_value in (False, True):
|
||||
kwargs["foreach"] = flag_value
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/118230
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, kwargs)
|
||||
_force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
|
||||
|
||||
# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
|
||||
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
|
||||
@ -539,6 +545,11 @@ class TestOptimRenewed(TestCase):
|
||||
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
|
||||
for optim_input in all_optim_inputs:
|
||||
# See https://github.com/pytorch/pytorch/issues/117836 and #118230
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
weight_kwargs = optim_input.kwargs
|
||||
bias_kwargs = deepcopy(optim_input.kwargs)
|
||||
bias_kwargs["weight_decay"] = 0.0
|
||||
@ -575,6 +586,11 @@ class TestOptimRenewed(TestCase):
|
||||
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
|
||||
for optim_input in all_optim_inputs:
|
||||
# See https://github.com/pytorch/pytorch/issues/117836 and #118230
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
# optim_input.kwargs will be the param group kwargs, which should have >0 lr
|
||||
if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
|
||||
optim_input.kwargs["lr"] = 1e-3
|
||||
@ -630,10 +646,12 @@ class TestOptimRenewed(TestCase):
|
||||
return torch.tensor([1], device=device, dtype=dtype)
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
optimizer.step(closure)
|
||||
self.assertEqual(old_params, params)
|
||||
|
||||
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
@ -648,7 +666,10 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
kwargs = optim_input.kwargs
|
||||
_make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
|
||||
|
||||
if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
|
||||
and not kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
# params will decay even if grads are empty if weight_decay != 0,
|
||||
# and capturable doesn't work for CPU tensors
|
||||
@ -657,7 +678,7 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
# AdamW params will be updated regardless of grads due to lr, so make lr smaller
|
||||
if optim_cls.__name__ == "AdamW":
|
||||
kwargs["lr"] = torch.tensor(1e-4) if isinstance(kwargs.get("lr", 1e-4), torch.Tensor) else 1e-4
|
||||
kwargs["lr"] = torch.tensor(1e-5) if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) else 1e-5
|
||||
|
||||
if kwargs.get("differentiable", False):
|
||||
params = [param.clone()]
|
||||
@ -684,6 +705,10 @@ class TestOptimRenewed(TestCase):
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||
params = [Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)) for _ in range(2)]
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
optimizer.__repr__()
|
||||
|
||||
@ -706,6 +731,10 @@ class TestOptimRenewed(TestCase):
|
||||
return loss
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
|
||||
|
||||
@ -749,6 +778,10 @@ class TestOptimRenewed(TestCase):
|
||||
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
torch.manual_seed(1)
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(4, 2, 1, stride=2),
|
||||
@ -811,9 +844,8 @@ class TestOptimRenewed(TestCase):
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
kwargs = optim_input.kwargs
|
||||
# See https://github.com/pytorch/pytorch/issues/117836 for Adamax
|
||||
# See https://github.com/pytorch/pytorch/issues/118230 for RAdam
|
||||
if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
|
||||
if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
|
||||
and not kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
@ -843,6 +875,10 @@ class TestOptimRenewed(TestCase):
|
||||
return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
|
||||
|
||||
for optim_input in cpu_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
|
||||
for p in params:
|
||||
p.grad = torch.randn_like(p)
|
||||
@ -906,6 +942,10 @@ class TestOptimRenewed(TestCase):
|
||||
return {k for k in obj.__dict__ if not k.startswith("_")}
|
||||
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optimizer = optim_cls(params, **optim_input.kwargs)
|
||||
|
||||
# Make some state
|
||||
|
@ -1568,7 +1568,7 @@ class TorchPatcher:
|
||||
}
|
||||
|
||||
excluded_single_tensor = {
|
||||
radam, # https://github.com/pytorch/pytorch/issues/117807
|
||||
radam, # https://github.com/pytorch/pytorch/issues/118230
|
||||
}
|
||||
|
||||
for opt_mod in optimizer_modules:
|
||||
|
@ -74,7 +74,10 @@ class Adam(Optimizer):
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
|
||||
p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
|
||||
if group['capturable'] or group['fused']
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
|
@ -34,6 +34,9 @@ class Adamax(Optimizer):
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
if foreach is False and capturable:
|
||||
raise ValueError("Capturable not supported with single tensor Adamax")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
@ -53,13 +56,12 @@ class Adamax(Optimizer):
|
||||
group.setdefault("maximize", False)
|
||||
group.setdefault("differentiable", False)
|
||||
group.setdefault("capturable", False)
|
||||
state_values = list(self.state.values())
|
||||
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
|
||||
state_values[0]["step"]
|
||||
)
|
||||
if not step_is_tensor:
|
||||
for s in state_values:
|
||||
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
|
||||
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
|
||||
has_complex = False
|
||||
@ -265,6 +267,8 @@ def _single_tensor_adamax(
|
||||
capturable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
if capturable:
|
||||
raise RuntimeError("capturable is not supported for single tensor Adamax (when foreach=False)")
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i]
|
||||
@ -272,6 +276,7 @@ def _single_tensor_adamax(
|
||||
exp_avg = exp_avgs[i]
|
||||
exp_inf = exp_infs[i]
|
||||
step_t = state_steps[i]
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
@ -328,6 +333,11 @@ def _multi_tensor_adamax(
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if (not torch._utils.is_compiling() and capturable
|
||||
and not all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps))):
|
||||
raise RuntimeError("If capturable=True, params and state_steps must be CUDA tensors.")
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
|
||||
for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
|
||||
if has_complex:
|
||||
|
@ -83,7 +83,10 @@ class AdamW(Optimizer):
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
|
||||
p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
|
||||
if group['capturable'] or group['fused']
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
|
@ -34,7 +34,7 @@ class ASGD(Optimizer):
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
if foreach is False and capturable:
|
||||
if foreach is False and capturable and not is_compiling():
|
||||
raise ValueError("Capturable not supported with single tensor ASGD")
|
||||
|
||||
defaults = dict(
|
||||
@ -57,25 +57,20 @@ class ASGD(Optimizer):
|
||||
group.setdefault("maximize", False)
|
||||
group.setdefault("differentiable", False)
|
||||
group.setdefault("capturable", False)
|
||||
state_values = list(self.state.values())
|
||||
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
|
||||
state_values[0]["step"]
|
||||
)
|
||||
if not step_is_tensor:
|
||||
for s in state_values:
|
||||
s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
|
||||
eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
|
||||
state_values[0]["eta"]
|
||||
)
|
||||
if not eta_is_tensor:
|
||||
for s in state_values:
|
||||
s["eta"] = torch.tensor(s["eta"], dtype=_get_scalar_dtype())
|
||||
mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
|
||||
state_values[0]["mu"]
|
||||
)
|
||||
if not mu_is_tensor:
|
||||
for s in state_values:
|
||||
s["mu"] = torch.tensor(float(s["mu"]), dtype=_get_scalar_dtype())
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0:
|
||||
if not torch.is_tensor(p_state['step']):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
if not torch.is_tensor(p_state["eta"]):
|
||||
p_state["eta"] = (torch.tensor(p_state["eta"], dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group["capturable"] else torch.tensor(p_state["eta"], dtype=_get_scalar_dtype()))
|
||||
if not torch.is_tensor(p_state["mu"]):
|
||||
p_state["mu"] = (torch.tensor(p_state["mu"], dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group["capturable"] else torch.tensor(p_state["mu"], dtype=_get_scalar_dtype()))
|
||||
|
||||
|
||||
def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
|
||||
has_complex = False
|
||||
@ -206,8 +201,6 @@ def asgd(
|
||||
if foreach and not torch.jit.is_scripting():
|
||||
func = _multi_tensor_asgd
|
||||
else:
|
||||
if capturable and not is_compiling():
|
||||
raise RuntimeError("Capturable not supported with single tensor ASGD")
|
||||
func = _single_tensor_asgd
|
||||
|
||||
func(
|
||||
@ -247,6 +240,9 @@ def _single_tensor_asgd(
|
||||
capturable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
if capturable and not is_compiling():
|
||||
raise RuntimeError("capturable is not supported for single tensor ASGD (when foreach=False)")
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i]
|
||||
grad = grad if not maximize else -grad
|
||||
@ -304,12 +300,17 @@ def _multi_tensor_asgd(
|
||||
capturable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
assert all(p.is_cuda and mu.is_cuda and eta.is_cuda and step.is_cuda
|
||||
for p, mu, eta, step in zip(params, mus, etas, state_steps)), \
|
||||
"If capturable=True, params, mu_products, and state_steps must be CUDA tensors."
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
|
||||
for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
|
||||
grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
|
||||
|
@ -37,15 +37,18 @@ class NAdam(Optimizer):
|
||||
group.setdefault('capturable', False)
|
||||
group.setdefault('differentiable', False)
|
||||
group.setdefault('decoupled_weight_decay', False)
|
||||
state_values = list(self.state.values())
|
||||
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
|
||||
if not step_is_tensor:
|
||||
for s in state_values:
|
||||
s['step'] = torch.tensor(float(s['step']), dtype=_get_scalar_dtype())
|
||||
mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
|
||||
if not mu_product_is_tensor:
|
||||
for s in state_values:
|
||||
s['mu_product'] = torch.tensor(s['mu_product'], dtype=_get_scalar_dtype())
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0:
|
||||
if not torch.is_tensor(p_state['step']):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
|
||||
if not torch.is_tensor(p_state['mu_product']):
|
||||
mu_prod_val = p_state["mu_product"]
|
||||
p_state["mu_product"] = (torch.tensor(mu_prod_val, dtype=_get_scalar_dtype(), device=p.device)
|
||||
if group['capturable'] else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()))
|
||||
|
||||
|
||||
def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
|
||||
has_complex = False
|
||||
|
@ -44,6 +44,10 @@ class RAdam(Optimizer):
|
||||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
if foreach is False and capturable:
|
||||
raise ValueError("Capturable not supported with single tensor RAdam")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
@ -208,7 +212,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
|
||||
decay as in AdamW to obtain RAdamW (default: False)
|
||||
{_foreach_doc}
|
||||
{_differentiable_doc}
|
||||
{_capturable_doc}
|
||||
{_capturable_doc} For RAdam, capturable is only supported when foreach=True.
|
||||
|
||||
.. _On the variance of the adaptive learning rate and beyond:
|
||||
https://arxiv.org/abs/1908.03265
|
||||
@ -297,7 +301,7 @@ def _single_tensor_radam(
|
||||
has_complex: bool,
|
||||
):
|
||||
if capturable:
|
||||
raise RuntimeError("capturable is not supported for single tensor radam")
|
||||
raise RuntimeError("capturable is not supported for single tensor RAdam (when foreach=False)")
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i]
|
||||
|
@ -114,6 +114,8 @@ class OptimizerInfo:
|
||||
supports_param_groups: bool = True,
|
||||
# whether the optimizer supports parameters on multiple devices
|
||||
supports_multiple_devices: bool = True,
|
||||
# whether the optimizer ONLY supports capturable on foreach vs. both foreach and forloop
|
||||
only_supports_capturable_on_foreach: bool = False,
|
||||
skips=(), # Indicates which tests to skip
|
||||
decorators=None, # Additional decorators to apply to generated tests
|
||||
optim_error_inputs_func=None, # Function to generate optim inputs that error
|
||||
@ -126,6 +128,7 @@ class OptimizerInfo:
|
||||
self.step_requires_closure = step_requires_closure
|
||||
self.supports_param_groups = supports_param_groups
|
||||
self.supports_multiple_devices = supports_multiple_devices
|
||||
self.only_supports_capturable_on_foreach = only_supports_capturable_on_foreach
|
||||
self.decorators = (
|
||||
*(decorators if decorators else []),
|
||||
*(skips if skips else []),
|
||||
@ -262,7 +265,7 @@ def get_error_inputs_for_all_optims(device, dtype):
|
||||
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
|
||||
|
||||
|
||||
def optim_inputs_func_adadelta(device=None):
|
||||
def optim_inputs_func_adadelta(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||||
@ -297,7 +300,7 @@ def optim_error_inputs_func_adadelta(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_adagrad(device=None):
|
||||
def optim_inputs_func_adagrad(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(
|
||||
@ -341,7 +344,7 @@ def optim_error_inputs_func_adagrad(device, dtype):
|
||||
|
||||
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
|
||||
# with all implementation code paths...
|
||||
def optim_inputs_func_adam(device=None):
|
||||
def optim_inputs_func_adam(device):
|
||||
cuda_supported_configs = [
|
||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||
OptimizerInput(
|
||||
@ -370,7 +373,7 @@ def optim_inputs_func_adam(device=None):
|
||||
OptimizerInput(
|
||||
params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
|
||||
),
|
||||
] + (cuda_supported_configs if str(device) == "cuda" else [])
|
||||
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||
|
||||
|
||||
def optim_error_inputs_func_adam(device, dtype):
|
||||
@ -405,7 +408,7 @@ def optim_error_inputs_func_adam(device, dtype):
|
||||
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
|
||||
),
|
||||
]
|
||||
if str(device) == "cuda":
|
||||
if "cuda" in str(device):
|
||||
sample_tensor = torch.empty((), device=device, dtype=dtype)
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
@ -430,7 +433,7 @@ def optim_error_inputs_func_adam(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_adamax(device=None):
|
||||
def optim_inputs_func_adamax(device):
|
||||
cuda_supported_configs = [
|
||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||
OptimizerInput(
|
||||
@ -461,11 +464,23 @@ def optim_inputs_func_adamax(device=None):
|
||||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||||
desc="maximize",
|
||||
),
|
||||
] + (cuda_supported_configs if str(device) == "cuda" else [])
|
||||
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||
|
||||
|
||||
def optim_error_inputs_func_adamax(device, dtype):
|
||||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||||
if "cuda" in str(device):
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs=dict(foreach=False, capturable=True),
|
||||
desc="single tensor capturable not supported",
|
||||
),
|
||||
error_type=ValueError,
|
||||
error_regex="Capturable not supported with single tensor Adamax",
|
||||
)
|
||||
]
|
||||
if str(device) == "cpu":
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
@ -481,15 +496,33 @@ def optim_error_inputs_func_adamax(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_adamw(device=None):
|
||||
return optim_inputs_func_adam(device=device)
|
||||
def optim_inputs_func_adamw(device):
|
||||
return optim_inputs_func_adam(device)
|
||||
|
||||
|
||||
def optim_error_inputs_func_adamw(device, dtype):
|
||||
return optim_error_inputs_func_adam(device, dtype)
|
||||
|
||||
|
||||
def optim_inputs_func_asgd(device=None):
|
||||
def optim_inputs_func_asgd(device):
|
||||
cuda_supported_configs = [
|
||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={"maximize": True, "capturable": True},
|
||||
desc="maximize, capturable",
|
||||
),
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={"weight_decay": 0.1, "capturable": True},
|
||||
desc="weight_decay, capturable",
|
||||
),
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
|
||||
desc="maximize, weight_decay, capturable",
|
||||
),
|
||||
]
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
|
||||
@ -501,13 +534,25 @@ def optim_inputs_func_asgd(device=None):
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||||
desc="maximize",
|
||||
desc="maximize, nonzero weight_decay",
|
||||
),
|
||||
]
|
||||
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||
|
||||
|
||||
def optim_error_inputs_func_asgd(device, dtype):
|
||||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||||
if "cuda" in str(device):
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs=dict(foreach=False, capturable=True),
|
||||
desc="single tensor capturable not supported",
|
||||
),
|
||||
error_type=ValueError,
|
||||
error_regex="Capturable not supported with single tensor ASGD",
|
||||
)
|
||||
]
|
||||
if str(device) == "cpu":
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
@ -523,7 +568,7 @@ def optim_error_inputs_func_asgd(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_lbfgs(device=None):
|
||||
def optim_inputs_func_lbfgs(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||||
@ -544,7 +589,7 @@ def optim_error_inputs_func_lbfgs(device, dtype):
|
||||
|
||||
|
||||
# Weird story bro, NAdam and RAdam do not have maximize.
|
||||
def optim_inputs_func_nadam(device=None):
|
||||
def optim_inputs_func_nadam(device):
|
||||
cuda_supported_configs = [
|
||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||
OptimizerInput(
|
||||
@ -585,7 +630,7 @@ def optim_inputs_func_nadam(device=None):
|
||||
},
|
||||
desc="decoupled_weight_decay",
|
||||
),
|
||||
] + (cuda_supported_configs if str(device) == "cuda" else [])
|
||||
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||
|
||||
|
||||
def optim_error_inputs_func_nadam(device, dtype):
|
||||
@ -653,6 +698,18 @@ def optim_inputs_func_radam(device=None):
|
||||
|
||||
def optim_error_inputs_func_radam(device, dtype):
|
||||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||||
if "cuda" in str(device):
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
OptimizerInput(
|
||||
params=None,
|
||||
kwargs=dict(foreach=False, capturable=True),
|
||||
desc="single tensor capturable not supported",
|
||||
),
|
||||
error_type=ValueError,
|
||||
error_regex="Capturable not supported with single tensor RAdam",
|
||||
),
|
||||
]
|
||||
if str(device) == "cpu":
|
||||
error_inputs += [
|
||||
ErrorOptimizerInput(
|
||||
@ -677,7 +734,7 @@ def optim_error_inputs_func_radam(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_rmsprop(device=None):
|
||||
def optim_inputs_func_rmsprop(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
|
||||
@ -724,7 +781,7 @@ def optim_error_inputs_func_rmsprop(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_rprop(device=None):
|
||||
def optim_inputs_func_rprop(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
|
||||
@ -757,7 +814,7 @@ def optim_error_inputs_func_rprop(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_sgd(device=None):
|
||||
def optim_inputs_func_sgd(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
|
||||
@ -802,7 +859,7 @@ def optim_error_inputs_func_sgd(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def optim_inputs_func_sparseadam(device=None):
|
||||
def optim_inputs_func_sparseadam(device):
|
||||
return [
|
||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||
OptimizerInput(
|
||||
@ -879,6 +936,14 @@ def optim_error_inputs_func_sparseadam(device, dtype):
|
||||
return error_inputs
|
||||
|
||||
|
||||
def _get_device_type(device: Union[str, torch.device]) -> str:
|
||||
# Returns the device type as a string, e.g., "cpu" or "cuda"
|
||||
if isinstance(device, torch.device):
|
||||
device = str(device.type)
|
||||
assert isinstance(device, str)
|
||||
return device.split(":")[0]
|
||||
|
||||
|
||||
def _get_optim_inputs_including_global_cliquey_kwargs(
|
||||
device, dtype, optim_info, skip=()
|
||||
) -> List[OptimizerInput]:
|
||||
@ -897,14 +962,20 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
|
||||
x in ["foreach", "fused", "differentiable"] for x in skip
|
||||
), "skip must be a subset of ['foreach', 'fused', 'differentiable']"
|
||||
|
||||
optim_inputs = optim_info.optim_inputs_func(device=device)
|
||||
optim_inputs = optim_info.optim_inputs_func(device)
|
||||
|
||||
supported_impls = tuple(
|
||||
x
|
||||
for x in optim_info.supported_impls
|
||||
if x not in skip
|
||||
and (str(device) in _get_fused_kernels_supported_devices() or x != "fused")
|
||||
and (str(device) in _get_foreach_kernels_supported_devices() or x != "foreach")
|
||||
and (
|
||||
_get_device_type(device) in _get_fused_kernels_supported_devices()
|
||||
or x != "fused"
|
||||
)
|
||||
and (
|
||||
_get_device_type(device) in _get_foreach_kernels_supported_devices()
|
||||
or x != "foreach"
|
||||
)
|
||||
)
|
||||
|
||||
all_optim_inputs = []
|
||||
@ -1131,6 +1202,7 @@ optim_db: List[OptimizerInfo] = [
|
||||
optim_inputs_func=optim_inputs_func_adamax,
|
||||
optim_error_inputs_func=optim_error_inputs_func_adamax,
|
||||
supported_impls=("foreach", "differentiable"),
|
||||
only_supports_capturable_on_foreach=True, # Remove this line when #117836 is done!
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
|
||||
@ -1197,6 +1269,62 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_deepcopy_copies_all_public_attrs",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"cpu fails due to #115607; both devices fail cuz #117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_can_load_older_state_dict",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_for_zero_grads",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_when_params_have_no_grad",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_load_nontensor_step",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_param_groups_weight_decay",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_param_groups_lr",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_state_dict_with_cuda_params",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_mixed_device_dtype",
|
||||
),
|
||||
),
|
||||
),
|
||||
OptimizerInfo(
|
||||
@ -1267,6 +1395,7 @@ optim_db: List[OptimizerInfo] = [
|
||||
optim_inputs_func=optim_inputs_func_asgd,
|
||||
optim_error_inputs_func=optim_error_inputs_func_asgd,
|
||||
supported_impls=("foreach", "differentiable"),
|
||||
only_supports_capturable_on_foreach=True, # Remove this line when #116052 is done!
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
@ -1455,6 +1584,7 @@ optim_db: List[OptimizerInfo] = [
|
||||
optim_inputs_func=optim_inputs_func_radam,
|
||||
optim_error_inputs_func=optim_error_inputs_func_radam,
|
||||
supported_impls=("foreach", "differentiable"),
|
||||
only_supports_capturable_on_foreach=True, # Remove this line when #118230 is done!
|
||||
skips=(
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
@ -1540,6 +1670,13 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_complex",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_for_zero_grads",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
@ -1568,13 +1705,6 @@ optim_db: List[OptimizerInfo] = [
|
||||
"TestOptimRenewed",
|
||||
"test_param_groups_lr",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
),
|
||||
"TestOptimRenewed",
|
||||
"test_step_is_noop_for_zero_grads",
|
||||
),
|
||||
DecorateInfo(
|
||||
skipIfTorchDynamo(
|
||||
"Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
|
||||
|
Reference in New Issue
Block a user