[optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority order:

1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors must be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach, but would be silently incorrect for the forloop/single-tensor path. I've added erroring for that and modified testing with many, many skips. Honestly, this PR has only further cemented my preference that we do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.

The following is not a bug, but just makes life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which can be a string or a `torch.device`.

Details for posterity:

4. Running the `test_foreach_matches_forloop` test and printing the configs that get run shows capturable getting included, which is correct.

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```

5. Running the `test_optimizer_can_be_printed` test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
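For readers skimming fix 1: the pattern applied across `Adam`, `AdamW`, `Adamax`, `ASGD`, and `NAdam` in the diff below boils down to the following sketch. It is an illustrative reduction, not the verbatim source; the real code uses the internal `_get_scalar_dtype()` helper where this sketch takes a `scalar_dtype` argument, and `_restore_step` is a made-up name.

```python
import torch

def _restore_step(p_state, p, capturable, fused, scalar_dtype=torch.float32):
    # Old checkpoints may store `step` as a plain Python number. The capturable
    # and fused kernels read `step` on-device, so the restored tensor must be
    # created on the param's device; the plain path keeps a CPU scalar tensor.
    if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
        step_val = float(p_state["step"])
        if capturable or fused:
            p_state["step"] = torch.tensor(step_val, dtype=scalar_dtype, device=p.device)
        else:
            p_state["step"] = torch.tensor(step_val, dtype=scalar_dtype)
```

Previously the tensor was always created on the default (CPU) device, which is what broke `load_state_dict` for the capturable and fused implementations.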
Commit b5ba80828f (parent 8b00e5aa12), committed by PyTorch MergeBot.
```diff
@@ -44,6 +44,10 @@ def get_optimizer_step(opt, closure=None):
 
 
 def make_test(optim_cls, closure=None, **kwargs):
+    # Remove this conditional when #118230 is fixed
+    if optim_cls.__name__ == "Adamax":
+        kwargs["foreach"] = True
+
     opt = optim_cls(model.parameters(), **kwargs)
 
     def test_fn(self):
@@ -74,10 +74,10 @@ KERNEL_COUNTS = {
     SGD: KernelCounts(multitensor=2, singletensor=8),
     RAdam: KernelCounts(
         multitensor=2, singletensor=None
-    ),  # Single tensor eager needs to be refactored to enable tracing
+    ),  # Single tensor eager needs to be refactored to enable tracing (#118230)
     Adamax: KernelCounts(
         multitensor=2, singletensor=None
-    ),  # Single tensor eager needs to be refactored to enable tracing
+    ),  # Single tensor eager needs to be refactored to enable tracing (#117836)
 }
 
 
@@ -87,12 +87,9 @@ def build_compiled_opt_kwarg_db():
         if optim_info.optim_cls not in KERNEL_COUNTS:
             continue
 
-        for optim_inputs in optim_info.optim_inputs_func():
-            for device in ["cpu", "cuda"]:
+        for device in ["cpu", "cuda"]:
+            for optim_inputs in optim_info.optim_inputs_func(device):
                 for foreach in [True, False]:
-                    if device == "cpu" and "capturable" in optim_inputs.kwargs:
-                        continue
 
                     kwargs = dict(optim_inputs.kwargs)
                     name = (
                         f"test_{optim_info.optim_cls.__name__.lower()}"
@@ -107,7 +104,21 @@ def build_compiled_opt_kwarg_db():
                     name += f"_{device}"
 
                     # Eager for-loop impl doesn't support capturable ASGD
-                    if name == "test_asgd_capturable_cuda":
+                    if name in [
+                        "test_asgd_capturable_cuda",
+                        "test_asgd_maximize_capturable_cuda",
+                        "test_asgd_weight_decay_capturable_cuda",
+                        "test_asgd_weight_decay_maximize_capturable_cuda",
+                    ]:
+                        continue
+
+                    # Adam(W) capturable cudagraphs manager is unexpectedly None, #119026
+                    if name in [
+                        "test_adam_amsgrad_capturable_cuda",
+                        "test_adam_foreach_amsgrad_capturable_cuda",
+                        "test_adamw_amsgrad_capturable_cuda",
+                        "test_adamw_foreach_amsgrad_capturable_cuda",
+                    ]:
                         continue
 
                     kwargs["foreach"] = foreach
```
```diff
@@ -583,9 +583,9 @@ class TestLazyModules(TestCase):
 
     @suppress_warnings
     def test_optimizer_pass(self):
+        # Add Adamax and RAdam when #118230 and #117836 are complete
         optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam,
-                      torch.optim.AdamW, torch.optim.Adamax,
-                      torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
+                      torch.optim.AdamW, torch.optim.ASGD, torch.optim.SGD, torch.optim.Rprop,
                       torch.optim.RMSprop, torch.optim.LBFGS]
 
         def run_step(module, optim):
```
```diff
@@ -19,10 +19,8 @@ from torch.testing._internal.common_utils import markDynamoStrictTest, parametri
 FP16_REDUCED_PRECISION = {'atol': 1e-5, 'rtol': 1e-4}
 
 
-def _make_radam_single_tensor_non_capturable(optim_cls, kwargs):
-    # Remove this function once https://github.com/pytorch/pytorch/issues/118230 is completed
-    if optim_cls == torch.optim.RAdam and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
-        # Radam does not support capturable single tensor
+def _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs):
+    if optim_info.only_supports_capturable_on_foreach and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
        kwargs["capturable"] = False
 
 @markDynamoStrictTest
```
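To see what the renamed helper does in isolation, here is a small demo; the only assumption is that `SimpleNamespace` stands in for the real `OptimizerInfo` object:

```python
from types import SimpleNamespace

def _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs):
    if optim_info.only_supports_capturable_on_foreach and not kwargs.get("foreach", False) and kwargs.get("capturable", False):
        kwargs["capturable"] = False

radam_like = SimpleNamespace(only_supports_capturable_on_foreach=True)

kwargs = {"capturable": True}  # no "foreach" key means the single-tensor/forloop path
_force_capturable_False_for_unsupported_single_tensor(radam_like, kwargs)
print(kwargs)  # {'capturable': False} -- downgraded so the forloop test can still run

kwargs = {"capturable": True, "foreach": True}
_force_capturable_False_for_unsupported_single_tensor(radam_like, kwargs)
print(kwargs)  # unchanged: the foreach path does support capturable
```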
```diff
@@ -71,6 +69,9 @@ class TestOptimRenewed(TestCase):
         for optim_input in optim_inputs:
             if "foreach" in optim_info.supported_impls:
                 optim_input.kwargs["foreach"] = False  # force forloop
+
+            _force_capturable_False_for_unsupported_single_tensor(optim_info, optim_input.kwargs)
+
             if contiguous:
                 weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
                 bias = Parameter(torch.randn((10), device=device, dtype=dtype))
@@ -79,8 +80,6 @@ class TestOptimRenewed(TestCase):
                 bias = Parameter(torch.randn((10, 2), device=device, dtype=dtype)[..., 0])
             input = torch.randn(5, device=device, dtype=dtype)
 
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
             optimizer = optim_cls([weight, bias], **optim_input.kwargs)
 
             def closure():
@@ -109,13 +108,14 @@ class TestOptimRenewed(TestCase):
     @optims(optim_db, dtypes=[torch.float32])
     def test_forloop_goes_right_direction_multigpu(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
-        optim_inputs = optim_info.optim_inputs_func(device="cuda")
+        optim_inputs = optim_info.optim_inputs_func(device=device)
         for optim_input in optim_inputs:
             if "foreach" in optim_info.supported_impls:
                 optim_input.kwargs["foreach"] = False  # force forloop
 
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
 
             weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
             bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
@@ -148,13 +148,19 @@ class TestOptimRenewed(TestCase):
     def test_complex(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
-        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
+        # Also skip fused, since our fused kernels do not support complex
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
+            device, dtype, optim_info, skip=("differentiable", "fused"))
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             # Last param is intentionally real to test that we can mix real and complex
             complex_params = [
-                torch.randn(10, 5, dtype=dtype, requires_grad=True),
-                torch.randn(10, dtype=dtype, requires_grad=True),
-                torch.randn(10, 5, dtype=torch.float32, requires_grad=True),
+                torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
+                torch.randn(10, device=device, dtype=dtype, requires_grad=True),
+                torch.randn(10, 5, device=device, dtype=torch.float32, requires_grad=True),
             ]
             real_params = [
                 (
@@ -164,8 +170,7 @@ class TestOptimRenewed(TestCase):
                 )
                 for param in complex_params
             ]
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
             complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
             real_optimizer = optim_cls(real_params, **optim_input.kwargs)
             real_steps = []
@@ -234,14 +239,13 @@ class TestOptimRenewed(TestCase):
         for optim_input in optim_inputs:
             updated_params, state = [], []
             kwargs = deepcopy(optim_input.kwargs)
-            if (kwargs.get("capturable", False) and str(device) == "cpu"):
+            if kwargs.get("capturable", False) and str(device) == "cpu":
                 # capturable is not supported on CPU
                 continue
             for flag_value in (False, True):
                 kwargs[flag] = flag_value
 
-                # https://github.com/pytorch/pytorch/issues/118230
-                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+                _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
 
                 input = torch.tensor(
                     [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
```
```diff
@@ -344,7 +348,10 @@ class TestOptimRenewed(TestCase):
         for optim_input in optim_inputs:
             updated_params, state = [], []
             kwargs = deepcopy(optim_input.kwargs)
-            if kwargs.get("capturable", False) and str(device) == "cpu":
+
+            _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
+
+            if kwargs.get("capturable", False) and str(device) == "cpu" :
                 # capturable is not supported on CPU
                 continue
             for use_impl in (False, True):
@@ -357,8 +364,6 @@ class TestOptimRenewed(TestCase):
                     p_clone.grad = p.grad.clone().detach()
                     params_clone.append(p_clone)
 
-            # https://github.com/pytorch/pytorch/issues/118230
-            _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
             optimizer = optim_cls(params_clone, **kwargs)
             for _ in range(kIterations):
                 optimizer.step()
@@ -393,6 +398,7 @@ class TestOptimRenewed(TestCase):
         # default dtype is higher prec float64
         old_default_dtype = torch.get_default_dtype()
         for default_dtype in [torch.float64, torch.float16]:
+            try:
                 torch.set_default_dtype(default_dtype)
                 self._test_derived_optimizers(
                     device,
@@ -402,6 +408,7 @@ class TestOptimRenewed(TestCase):
                     reduced_precision=default_dtype == torch.float16,
                     assert_step_dtype=torch.float64 if default_dtype == torch.float64 else torch.float32,
                 )
+            finally:
                 torch.set_default_dtype(old_default_dtype)
 
 
@@ -431,8 +438,7 @@ class TestOptimRenewed(TestCase):
             for flag_value in (False, True):
                 kwargs["foreach"] = flag_value
 
-                # https://github.com/pytorch/pytorch/issues/118230
-                _make_radam_single_tensor_non_capturable(optim_cls, kwargs)
+                _force_capturable_False_for_unsupported_single_tensor(optim_info, kwargs)
 
                 # The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
                 # meaning any tensor that occupies <512 bytes of memory will allocate a whole
@@ -539,6 +545,11 @@ class TestOptimRenewed(TestCase):
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            # See https://github.com/pytorch/pytorch/issues/117836 and #118230
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             weight_kwargs = optim_input.kwargs
             bias_kwargs = deepcopy(optim_input.kwargs)
             bias_kwargs["weight_decay"] = 0.0
@@ -575,6 +586,11 @@ class TestOptimRenewed(TestCase):
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            # See https://github.com/pytorch/pytorch/issues/117836 and #118230
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             # optim_input.kwargs will be the param group kwargs, which should have >0 lr
             if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
                 optim_input.kwargs["lr"] = 1e-3
@@ -630,10 +646,12 @@ class TestOptimRenewed(TestCase):
             return torch.tensor([1], device=device, dtype=dtype)
 
         for optim_input in all_optim_inputs:
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.step(closure)
-            self.assertEqual(old_params, params)
 
 
     @optims(optim_db, dtypes=[torch.float32])
@@ -648,7 +666,10 @@ class TestOptimRenewed(TestCase):
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
-            _make_radam_single_tensor_non_capturable(optim_cls, optim_input.kwargs)
+
+            if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
+                    and not kwargs.get("foreach", False)):
+                continue
 
             # params will decay even if grads are empty if weight_decay != 0,
             # and capturable doesn't work for CPU tensors
```
```diff
@@ -657,7 +678,7 @@ class TestOptimRenewed(TestCase):
 
             # AdamW params will be updated regardless of grads due to lr, so make lr smaller
             if optim_cls.__name__ == "AdamW":
-                kwargs["lr"] = torch.tensor(1e-4) if isinstance(kwargs.get("lr", 1e-4), torch.Tensor) else 1e-4
+                kwargs["lr"] = torch.tensor(1e-5) if isinstance(kwargs.get("lr", 1e-5), torch.Tensor) else 1e-5
 
             if kwargs.get("differentiable", False):
                 params = [param.clone()]
@@ -684,6 +705,10 @@ class TestOptimRenewed(TestCase):
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
         params = [Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)) for _ in range(2)]
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             optimizer.__repr__()
 
@@ -706,6 +731,10 @@ class TestOptimRenewed(TestCase):
             return loss
 
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
             closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)
 
@@ -749,6 +778,10 @@ class TestOptimRenewed(TestCase):
         # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             torch.manual_seed(1)
             model = torch.nn.Sequential(
                 torch.nn.Conv2d(4, 2, 1, stride=2),
@@ -811,9 +844,8 @@ class TestOptimRenewed(TestCase):
 
         for optim_input in all_optim_inputs:
             kwargs = optim_input.kwargs
-            # See https://github.com/pytorch/pytorch/issues/117836 for Adamax
-            # See https://github.com/pytorch/pytorch/issues/118230 for RAdam
-            if optim_cls.__name__ in ["Adamax", "RAdam"] and kwargs.get("capturable", False) and not kwargs.get("foreach", False):
+            if (optim_info.only_supports_capturable_on_foreach and kwargs.get("capturable", False)
+                    and not kwargs.get("foreach", False)):
                 continue
 
             optimizer = optim_cls(params, **optim_input.kwargs)
@@ -843,6 +875,10 @@ class TestOptimRenewed(TestCase):
             return lbfgs_loss if optim_cls.__name__ == "LBFGS" else None
 
         for optim_input in cpu_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             params = [Parameter(torch.randn(2, 3, device="cpu", dtype=dtype)) for _ in range(2)]
             for p in params:
                 p.grad = torch.randn_like(p)
@@ -906,6 +942,10 @@ class TestOptimRenewed(TestCase):
             return {k for k in obj.__dict__ if not k.startswith("_")}
 
         for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
             optimizer = optim_cls(params, **optim_input.kwargs)
 
             # Make some state
```
```diff
@@ -1568,7 +1568,7 @@ class TorchPatcher:
         }
 
         excluded_single_tensor = {
-            radam,  # https://github.com/pytorch/pytorch/issues/117807
+            radam,  # https://github.com/pytorch/pytorch/issues/118230
         }
 
         for opt_mod in optimizer_modules:
```
```diff
@@ -74,7 +74,10 @@ class Adam(Optimizer):
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
-                    p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
+                                       if group['capturable'] or group['fused']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(
         self,
```
```diff
@@ -34,6 +34,9 @@ class Adamax(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
+        if foreach is False and capturable:
+            raise ValueError("Capturable not supported with single tensor Adamax")
+
         defaults = dict(
             lr=lr,
             betas=betas,
```
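With the constructor check above in place, the unsupported combination fails fast at construction time instead of silently running the wrong math. A minimal repro of the post-PR behavior (no CUDA needed, since the check precedes any device logic):

```python
import torch

params = [torch.randn(2, 2, requires_grad=True)]
try:
    torch.optim.Adamax(params, foreach=False, capturable=True)
except ValueError as e:
    print(e)  # Capturable not supported with single tensor Adamax
```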
```diff
@@ -53,13 +56,12 @@ class Adamax(Optimizer):
             group.setdefault("maximize", False)
             group.setdefault("differentiable", False)
             group.setdefault("capturable", False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["step"]
-        )
-        if not step_is_tensor:
-            for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device) if group['capturable']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps):
         has_complex = False
@@ -265,6 +267,8 @@ def _single_tensor_adamax(
     capturable: bool,
     has_complex: bool,
 ):
+    if capturable:
+        raise RuntimeError("capturable is not supported for single tensor Adamax (when foreach=False)")
 
     for i, param in enumerate(params):
         grad = grads[i]
@@ -272,6 +276,7 @@ def _single_tensor_adamax(
         exp_avg = exp_avgs[i]
         exp_inf = exp_infs[i]
         step_t = state_steps[i]
+
         # update step
         step_t += 1
 
@@ -328,6 +333,11 @@ def _multi_tensor_adamax(
     if len(params) == 0:
         return
 
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if (not torch._utils.is_compiling() and capturable
+            and not all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps))):
+        raise RuntimeError("If capturable=True, params and state_steps must be CUDA tensors.")
+
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
     for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
         if has_complex:
```
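The eager CUDA check added above is what fix 2 is about: leaving capturable state on CPU now errors at `step()` time instead of producing wrong results. A sketch of the failure mode this guards against (assuming eager mode, i.e. not under `torch.compile`, and the dispatch behavior as of this PR):

```python
import torch

params = [torch.randn(2, 2, requires_grad=True)]  # CPU tensors on purpose
params[0].grad = torch.randn(2, 2)

opt = torch.optim.Adamax(params, foreach=True, capturable=True)
try:
    opt.step()
except RuntimeError as e:
    print(e)  # If capturable=True, params and state_steps must be CUDA tensors.
```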
```diff
@@ -83,7 +83,10 @@ class AdamW(Optimizer):
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
-                    p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))
+                    step_val = float(p_state["step"])
+                    p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(is_fused=fused), device=p.device)
+                                       if group['capturable'] or group['fused']
+                                       else torch.tensor(step_val, dtype=_get_scalar_dtype()))
 
     def _init_group(
         self,
```
```diff
@@ -34,7 +34,7 @@ class ASGD(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 
-        if foreach is False and capturable:
+        if foreach is False and capturable and not is_compiling():
             raise ValueError("Capturable not supported with single tensor ASGD")
 
         defaults = dict(
@@ -57,25 +57,20 @@ class ASGD(Optimizer):
             group.setdefault("maximize", False)
             group.setdefault("differentiable", False)
             group.setdefault("capturable", False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["step"]
-        )
-        if not step_is_tensor:
-            for s in state_values:
-                s["step"] = torch.tensor(float(s["step"]), dtype=_get_scalar_dtype())
-        eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["eta"]
-        )
-        if not eta_is_tensor:
-            for s in state_values:
-                s["eta"] = torch.tensor(s["eta"], dtype=_get_scalar_dtype())
-        mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]["mu"]
-        )
-        if not mu_is_tensor:
-            for s in state_values:
-                s["mu"] = torch.tensor(float(s["mu"]), dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0:
+                    if not torch.is_tensor(p_state['step']):
+                        step_val = float(p_state["step"])
+                        p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
+                                           if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state["eta"]):
+                        p_state["eta"] = (torch.tensor(p_state["eta"], dtype=_get_scalar_dtype(), device=p.device)
+                                          if group["capturable"] else torch.tensor(p_state["eta"], dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state["mu"]):
+                        p_state["mu"] = (torch.tensor(p_state["mu"], dtype=_get_scalar_dtype(), device=p.device)
+                                         if group["capturable"] else torch.tensor(p_state["mu"], dtype=_get_scalar_dtype()))
+
 
     def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
         has_complex = False
@@ -206,8 +201,6 @@ def asgd(
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_asgd
     else:
-        if capturable and not is_compiling():
-            raise RuntimeError("Capturable not supported with single tensor ASGD")
         func = _single_tensor_asgd
 
     func(
@@ -247,6 +240,9 @@ def _single_tensor_asgd(
     capturable: bool,
    has_complex: bool,
 ):
+    if capturable and not is_compiling():
+        raise RuntimeError("capturable is not supported for single tensor ASGD (when foreach=False)")
+
     for i, param in enumerate(params):
         grad = grads[i]
         grad = grad if not maximize else -grad
@@ -304,12 +300,17 @@ def _multi_tensor_asgd(
     capturable: bool,
     has_complex: bool,
 ):
 
     if len(params) == 0:
         return
 
     assert not differentiable, "_foreach ops don't support autograd"
 
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
+        assert all(p.is_cuda and mu.is_cuda and eta.is_cuda and step.is_cuda
+                   for p, mu, eta, step in zip(params, mus, etas, state_steps)), \
+            "If capturable=True, params, mu_products, and state_steps must be CUDA tensors."
+
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
     for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
                         grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
```
```diff
@@ -37,15 +37,18 @@ class NAdam(Optimizer):
             group.setdefault('capturable', False)
             group.setdefault('differentiable', False)
             group.setdefault('decoupled_weight_decay', False)
-        state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
-        if not step_is_tensor:
-            for s in state_values:
-                s['step'] = torch.tensor(float(s['step']), dtype=_get_scalar_dtype())
-        mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
-        if not mu_product_is_tensor:
-            for s in state_values:
-                s['mu_product'] = torch.tensor(s['mu_product'], dtype=_get_scalar_dtype())
+            for p in group["params"]:
+                p_state = self.state.get(p, [])
+                if len(p_state) != 0:
+                    if not torch.is_tensor(p_state['step']):
+                        step_val = float(p_state["step"])
+                        p_state["step"] = (torch.tensor(step_val, dtype=_get_scalar_dtype(), device=p.device)
+                                           if group['capturable'] else torch.tensor(step_val, dtype=_get_scalar_dtype()))
+                    if not torch.is_tensor(p_state['mu_product']):
+                        mu_prod_val = p_state["mu_product"]
+                        p_state["mu_product"] = (torch.tensor(mu_prod_val, dtype=_get_scalar_dtype(), device=p.device)
+                                                 if group['capturable'] else torch.tensor(mu_prod_val, dtype=_get_scalar_dtype()))
 
 
     def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
         has_complex = False
```
```diff
@@ -44,6 +44,10 @@ class RAdam(Optimizer):
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        if foreach is False and capturable:
+            raise ValueError("Capturable not supported with single tensor RAdam")
+
         defaults = dict(
             lr=lr,
             betas=betas,
@@ -208,7 +212,7 @@ RAdam.__doc__ = r"""Implements RAdam algorithm.
             decay as in AdamW to obtain RAdamW (default: False)
         {_foreach_doc}
         {_differentiable_doc}
-        {_capturable_doc}
+        {_capturable_doc} For RAdam, capturable is only supported when foreach=True.
 
     .. _On the variance of the adaptive learning rate and beyond:
         https://arxiv.org/abs/1908.03265
@@ -297,7 +301,7 @@ def _single_tensor_radam(
     has_complex: bool,
 ):
     if capturable:
-        raise RuntimeError("capturable is not supported for single tensor radam")
+        raise RuntimeError("capturable is not supported for single tensor RAdam (when foreach=False)")
 
     for i, param in enumerate(params):
         grad = grads[i]
```
```diff
@@ -114,6 +114,8 @@ class OptimizerInfo:
         supports_param_groups: bool = True,
         # whether the optimizer supports parameters on multiple devices
         supports_multiple_devices: bool = True,
+        # whether the optimizer ONLY supports capturable on foreach vs. both foreach and forloop
+        only_supports_capturable_on_foreach: bool = False,
         skips=(),  # Indicates which tests to skip
         decorators=None,  # Additional decorators to apply to generated tests
         optim_error_inputs_func=None,  # Function to generate optim inputs that error
@@ -126,6 +128,7 @@ class OptimizerInfo:
         self.step_requires_closure = step_requires_closure
         self.supports_param_groups = supports_param_groups
         self.supports_multiple_devices = supports_multiple_devices
+        self.only_supports_capturable_on_foreach = only_supports_capturable_on_foreach
         self.decorators = (
             *(decorators if decorators else []),
             *(skips if skips else []),
```
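A hypothetical sketch of how an optimizer opts in to the new flag in the `optim_db` entries (the real entries carry more fields; the constructor arguments shown here mirror names from this file, but the exact entry is an assumption):

```python
# Hypothetical database entry: RAdam only supports capturable via foreach
# (single-tensor support is tracked in #118230), so tests can now skip or
# downgrade capturable configs on the forloop path via this flag.
OptimizerInfo(
    torch.optim.RAdam,
    optim_inputs_func=optim_inputs_func_radam,
    optim_error_inputs_func=optim_error_inputs_func_radam,
    supported_impls=("foreach", "differentiable"),
    only_supports_capturable_on_foreach=True,
),
```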
```diff
@@ -262,7 +265,7 @@ def get_error_inputs_for_all_optims(device, dtype):
 # global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
 
 
-def optim_inputs_func_adadelta(device=None):
+def optim_inputs_func_adadelta(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
@@ -297,7 +300,7 @@ def optim_error_inputs_func_adadelta(device, dtype):
     return error_inputs
 
 
-def optim_inputs_func_adagrad(device=None):
+def optim_inputs_func_adagrad(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(
@@ -341,7 +344,7 @@ def optim_error_inputs_func_adagrad(device, dtype):
 
 # TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
 # with all implementation code paths...
-def optim_inputs_func_adam(device=None):
+def optim_inputs_func_adam(device):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
         OptimizerInput(
@@ -370,7 +373,7 @@ def optim_inputs_func_adam(device=None):
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1, "amsgrad": True}, desc="amsgrad"
         ),
-    ] + (cuda_supported_configs if str(device) == "cuda" else [])
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_adam(device, dtype):
@@ -405,7 +408,7 @@ def optim_error_inputs_func_adam(device, dtype):
             error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
         ),
     ]
-    if str(device) == "cuda":
+    if "cuda" in str(device):
         sample_tensor = torch.empty((), device=device, dtype=dtype)
         error_inputs += [
             ErrorOptimizerInput(
@@ -430,7 +433,7 @@ def optim_error_inputs_func_adam(device, dtype):
     return error_inputs
 
 
-def optim_inputs_func_adamax(device=None):
+def optim_inputs_func_adamax(device):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
         OptimizerInput(
@@ -461,11 +464,23 @@ def optim_inputs_func_adamax(device=None):
             kwargs={"weight_decay": 0.1, "maximize": True},
             desc="maximize",
         ),
-    ] + (cuda_supported_configs if str(device) == "cuda" else [])
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_adamax(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor Adamax",
+            )
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
@@ -481,15 +496,33 @@ def optim_error_inputs_func_adamax(device, dtype):
     return error_inputs
 
 
-def optim_inputs_func_adamw(device=None):
-    return optim_inputs_func_adam(device=device)
+def optim_inputs_func_adamw(device):
+    return optim_inputs_func_adam(device)
 
 
 def optim_error_inputs_func_adamw(device, dtype):
     return optim_error_inputs_func_adam(device, dtype)
 
 
-def optim_inputs_func_asgd(device=None):
+def optim_inputs_func_asgd(device):
+    cuda_supported_configs = [
+        OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"maximize": True, "capturable": True},
+            desc="maximize, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
+            desc="maximize, weight_decay, capturable",
+        ),
+    ]
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
@@ -501,13 +534,25 @@ def optim_inputs_func_asgd(device=None):
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
+            desc="maximize, nonzero weight_decay",
         ),
-    ]
+    ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
 def optim_error_inputs_func_asgd(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor ASGD",
+            )
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
```
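Because the input funcs are now device-aware, the effect of the CUDA-only configs can be observed directly (this imports an internal test helper, so the exact descs may drift between releases):

```python
from torch.testing._internal.common_optimizers import optim_inputs_func_asgd

cpu_descs = {inp.desc for inp in optim_inputs_func_asgd("cpu")}
cuda_descs = {inp.desc for inp in optim_inputs_func_asgd("cuda")}

# cuda_supported_configs are appended only when "cuda" is in the device string:
print(sorted(cuda_descs - cpu_descs))
# ['capturable', 'maximize, capturable', 'maximize, weight_decay, capturable',
#  'weight_decay, capturable']
```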
@ -523,7 +568,7 @@ def optim_error_inputs_func_asgd(device, dtype):
|
|||||||
return error_inputs
|
return error_inputs
|
||||||
|
|
||||||
|
|
||||||
def optim_inputs_func_lbfgs(device=None):
|
def optim_inputs_func_lbfgs(device):
|
||||||
return [
|
return [
|
||||||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||||||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||||||
@ -544,7 +589,7 @@ def optim_error_inputs_func_lbfgs(device, dtype):
|
|||||||
|
|
||||||
|
|
||||||
# Weird story bro, NAdam and RAdam do not have maximize.
|
# Weird story bro, NAdam and RAdam do not have maximize.
|
||||||
def optim_inputs_func_nadam(device=None):
|
def optim_inputs_func_nadam(device):
|
||||||
cuda_supported_configs = [
|
cuda_supported_configs = [
|
||||||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||||||
OptimizerInput(
|
OptimizerInput(
|
||||||
@ -585,7 +630,7 @@ def optim_inputs_func_nadam(device=None):
|
|||||||
},
|
},
|
||||||
desc="decoupled_weight_decay",
|
desc="decoupled_weight_decay",
|
||||||
),
|
),
|
||||||
] + (cuda_supported_configs if str(device) == "cuda" else [])
|
] + (cuda_supported_configs if "cuda" in str(device) else [])
|
||||||
|
|
||||||
|
|
||||||
def optim_error_inputs_func_nadam(device, dtype):
|
def optim_error_inputs_func_nadam(device, dtype):
|
||||||
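The one-line change in the hunk above is bug 4 from the summary: the gate `str(device) == "cuda"` never matches an indexed device such as `"cuda:0"`, so the capturable configs were silently dropped and never tested. A quick string-only check (runs without a GPU):

```
import torch

for device in ("cuda", "cuda:0", torch.device("cuda", 0)):
    s = str(device)
    print(f"{s!r}: equality={s == 'cuda'}, membership={'cuda' in s}")
# 'cuda': equality=True, membership=True
# 'cuda:0': equality=False, membership=True
# 'cuda:0': equality=False, membership=True
```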
@@ -653,6 +698,18 @@ def optim_inputs_func_radam(device=None):

 def optim_error_inputs_func_radam(device, dtype):
     error_inputs = get_error_inputs_for_all_optims(device, dtype)
+    if "cuda" in str(device):
+        error_inputs += [
+            ErrorOptimizerInput(
+                OptimizerInput(
+                    params=None,
+                    kwargs=dict(foreach=False, capturable=True),
+                    desc="single tensor capturable not supported",
+                ),
+                error_type=ValueError,
+                error_regex="Capturable not supported with single tensor RAdam",
+            ),
+        ]
     if str(device) == "cpu":
         error_inputs += [
             ErrorOptimizerInput(
@@ -677,7 +734,7 @@ def optim_error_inputs_func_radam(device, dtype):
     return error_inputs


-def optim_inputs_func_rmsprop(device=None):
+def optim_inputs_func_rmsprop(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
@@ -724,7 +781,7 @@ def optim_error_inputs_func_rmsprop(device, dtype):
     return error_inputs


-def optim_inputs_func_rprop(device=None):
+def optim_inputs_func_rprop(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
@@ -757,7 +814,7 @@ def optim_error_inputs_func_rprop(device, dtype):
     return error_inputs


-def optim_inputs_func_sgd(device=None):
+def optim_inputs_func_sgd(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
@@ -802,7 +859,7 @@ def optim_error_inputs_func_sgd(device, dtype):
     return error_inputs


-def optim_inputs_func_sparseadam(device=None):
+def optim_inputs_func_sparseadam(device):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(
@@ -879,6 +936,14 @@ def optim_error_inputs_func_sparseadam(device, dtype):
     return error_inputs


+def _get_device_type(device: Union[str, torch.device]) -> str:
+    # Returns the device type as a string, e.g., "cpu" or "cuda"
+    if isinstance(device, torch.device):
+        device = str(device.type)
+    assert isinstance(device, str)
+    return device.split(":")[0]
+
+
 def _get_optim_inputs_including_global_cliquey_kwargs(
     device, dtype, optim_info, skip=()
 ) -> List[OptimizerInput]:
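The new `_get_device_type` helper normalizes both strings and `torch.device` objects down to the bare device type, which is what the kernel-support lists contain. A quick sanity check using an inline copy of the helper (the real definition lives in torch/testing/_internal/common_optimizers.py, as added above):

```
from typing import Union

import torch

def _get_device_type(device: Union[str, torch.device]) -> str:
    if isinstance(device, torch.device):
        device = str(device.type)
    assert isinstance(device, str)
    return device.split(":")[0]

assert _get_device_type("cuda:1") == "cuda"
assert _get_device_type(torch.device("cuda", 0)) == "cuda"
assert _get_device_type(torch.device("cpu")) == "cpu"
```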
@@ -897,14 +962,20 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
         x in ["foreach", "fused", "differentiable"] for x in skip
     ), "skip must be a subset of ['foreach', 'fused', 'differentiable']"

-    optim_inputs = optim_info.optim_inputs_func(device=device)
+    optim_inputs = optim_info.optim_inputs_func(device)

     supported_impls = tuple(
         x
         for x in optim_info.supported_impls
         if x not in skip
-        and (str(device) in _get_fused_kernels_supported_devices() or x != "fused")
-        and (str(device) in _get_foreach_kernels_supported_devices() or x != "foreach")
+        and (
+            _get_device_type(device) in _get_fused_kernels_supported_devices()
+            or x != "fused"
+        )
+        and (
+            _get_device_type(device) in _get_foreach_kernels_supported_devices()
+            or x != "foreach"
+        )
     )

     all_optim_inputs = []
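This hunk is bug 5 from the summary: `_get_optim_inputs_including_global_cliquey_kwargs` compared the full device string against the fused/foreach supported-device lists, which hold bare device types, so indexed devices silently lost those impls. A sketch of the difference (the supported list below is an illustrative stand-in, not the helpers' real return value):

```
import torch

fused_devices = ["cuda"]  # stand-in for _get_fused_kernels_supported_devices()
device = torch.device("cuda", 0)

old_check = str(device) in fused_devices  # "cuda:0" -> False: fused was dropped
new_check = device.type in fused_devices  # "cuda"   -> True, as intended
print(old_check, new_check)  # False True
```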
@@ -1131,6 +1202,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_adamax,
         optim_error_inputs_func=optim_error_inputs_func_adamax,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True, # Remove this line when #117836 is done!
         skips=(
             DecorateInfo(
                 skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1197,6 +1269,62 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_deepcopy_copies_all_public_attrs",
             ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "cpu fails due to #115607; both devices fail cuz #117836"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_when_params_have_no_grad",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_load_nontensor_step",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_weight_decay",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_param_groups_lr",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_state_dict_with_cuda_params",
+            ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/117836"
+                ),
+                "TestOptimRenewed",
+                "test_mixed_device_dtype",
+            ),
         ),
     ),
     OptimizerInfo(
@@ -1267,6 +1395,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_asgd,
         optim_error_inputs_func=optim_error_inputs_func_asgd,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True, # Remove this line when #116052 is done!
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo(
@@ -1455,6 +1584,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_radam,
         optim_error_inputs_func=optim_error_inputs_func_radam,
         supported_impls=("foreach", "differentiable"),
+        only_supports_capturable_on_foreach=True, # Remove this line when #118230 is done!
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo(
@@ -1540,6 +1670,13 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_complex",
             ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
+                ),
+                "TestOptimRenewed",
+                "test_step_is_noop_for_zero_grads",
+            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
@@ -1568,13 +1705,6 @@ optim_db: List[OptimizerInfo] = [
                 "TestOptimRenewed",
                 "test_param_groups_lr",
             ),
-            DecorateInfo(
-                skipIfTorchDynamo(
-                    "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"
-                ),
-                "TestOptimRenewed",
-                "test_step_is_noop_for_zero_grads",
-            ),
             DecorateInfo(
                 skipIfTorchDynamo(
                     "Should be fixed by https://github.com/pytorch/pytorch/issues/118230"