Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)

This PR addresses issue #123451. Specifically, the ```test_graph_optims``` and ```test_graph_scaling_fused_optimizers``` functions in ```test_cuda.py``` have been updated so that they now use the new OptimizerInfo infrastructure.

Lintrunner passed:
```
$ lintrunner test/test_cuda.py
ok No lint issues.
```
Tests passed:
```
>python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s

OK (skipped=9)

>python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s

OK (skipped=3)
```
Both functions have been moved to the newly created TestCase class ```TestCudaOptims```. The tests are mostly unchanged, except that the ```@optims``` decorator at the top of each function now parametrizes it over each optimizer listed in the decorator, replacing the explicit for loop that previously iterated through the optimizers; a minimal sketch of the pattern follows.
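
For reference, here is a rough sketch of the decorator pattern, assuming the ```optims```/```optim_db``` helpers from ```common_optimizers.py```; the filter and test body below are illustrative, not the exact code in the PR:
```python
import torch
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase


class TestCudaOptims(TestCase):
    # @optims parametrizes the test over every OptimizerInfo in the given list,
    # so each optimizer becomes its own test instance instead of one iteration
    # of a hand-written for loop.
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],  # illustrative filter
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        ...  # build params, capture a CUDA graph, and step the optimizer
```
In the actual test file the class is picked up by ```instantiate_device_type_tests```, which is what supplies the ```device``` and ```dtype``` arguments at runtime.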

I was unable to use ```_get_optim_inputs_including_global_cliquey_kwargs``` to get all kwargs for each optimizer, because some of the kwargs used in the original ```test_graph_optims``` function are not returned by the new OptimizerInfo infrastructure. Specifically, for the ```torch.optim.rmsprop.RMSprop``` optimizer, the following kwargs are not returned whenever ```_get_optim_inputs_including_global_cliquey_kwargs``` is called:
```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{'foreach': True, 'maximize': True, 'weight_decay': 0}
```
I ran into the same issue with ```test_graph_scaling_fused_optimizers```: for the ```torch.optim.adamw.AdamW``` optimizer, the following kwarg was not returned whenever ```optim_info.optim_inputs_func(device=device)``` was called:
```
{'amsgrad': True}
```
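
To double-check which configs the infrastructure actually yields, a quick inspection loop like the one below can be used. This is a sketch that assumes the ```(device, dtype, optim_info)``` calling convention of the helper in ```common_optimizers.py```:
```python
import torch
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
)

for optim_info in optim_db:
    # Dump every kwargs dict the infrastructure generates for this optimizer on
    # CUDA; configs that the original tests exercised but that are absent here
    # (e.g. {'amsgrad': True} for AdamW) simply never get printed.
    for optim_input in _get_optim_inputs_including_global_cliquey_kwargs(
        "cuda", torch.float32, optim_info
    ):
        print(optim_info.optim_cls.__name__, optim_input.kwargs)
```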

Because of this, I resorted to storing the kwargs for each optimizer in a dictionary (roughly sketched below), which I know is less than ideal. I was wondering whether I should instead rely on the OptimizerInfo infrastructure for all the kwargs, even though it is missing some of them.
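
The stopgap is essentially a per-optimizer table of extra configs keyed by optimizer class. The shape below is illustrative and only includes the kwargs mentioned above, not the full dictionary in ```test_cuda.py```:
```python
import torch

# Illustrative stopgap: kwargs the OptimizerInfo infrastructure does not return,
# kept in a plain dictionary keyed by optimizer class (not the PR's exact table).
missing_kwargs = {
    torch.optim.RMSprop: [
        {"foreach": False, "maximize": True, "weight_decay": 0},
        {"foreach": True, "maximize": True, "weight_decay": 0},
    ],
    torch.optim.AdamW: [
        {"amsgrad": True},
    ],
}

for optim_cls, kwargs_list in missing_kwargs.items():
    for kwargs in kwargs_list:
        print(f"{optim_cls.__name__}: would also test {kwargs}")
```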

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
Author: jayanth domalapalli
Date: 2024-05-20 06:20:45 +00:00
Committed by: PyTorch MergeBot
Parent: 5fb11cda4f
Commit: cf35a591b9
3 changed files with 225 additions and 234 deletions


@@ -123,6 +123,8 @@ class OptimizerInfo:
supported_impls: Tuple[str] = ("foreach", "differentiable"),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
# the optimizer constructor supports passing in capturable as a kwarg
has_capturable_arg: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
only_supports_sparse_grads: bool = False,
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
@@ -147,6 +149,7 @@ class OptimizerInfo:
self.scheduler_inputs = scheduler_inputs
self.supported_impls = supported_impls
self.supports_sparse = supports_sparse
self.has_capturable_arg = has_capturable_arg
self.metadata_for_sparse = metadata_for_sparse
self.only_supports_sparse_grads = only_supports_sparse_grads
self.supports_complex = supports_complex
@@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
desc="maximize, weight_decay",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
@@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
kwargs={"maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])
@@ -683,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
desc="weight_decay, momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
},
desc="decoupled_weight_decay",
@@ -818,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
"weight_decay": 0.1,
},
desc="maximize, weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
@@ -836,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None):
"momentum": 0.1,
"maximize": True,
},
desc="maximize",
desc="maximize, centered, weight_decay, w/ momentum",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])
@@ -907,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None):
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "dampening": 0.5},
@@ -916,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None):
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "weight_decay": 0.1},
desc="non-zero weight_decay",
desc="weight_decay w/ momentum",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
]
@@ -1097,6 +1130,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adadelta,
optim_error_inputs_func=optim_error_inputs_func_adadelta,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1232,6 +1266,7 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
has_capturable_arg=True,
decorators=(
# Expected floating point error between fused and compiled forloop
DecorateInfo(
@@ -1298,6 +1333,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adamax,
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1348,6 +1384,7 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
has_capturable_arg=True,
decorators=(
# Expected error between compiled forloop and fused optimizers
DecorateInfo(
@@ -1414,6 +1451,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_asgd,
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1506,6 +1544,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_nadam,
optim_error_inputs_func=optim_error_inputs_func_nadam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1561,6 +1600,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_radam,
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1606,6 +1646,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_rmsprop,
optim_error_inputs_func=optim_error_inputs_func_rmsprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1655,6 +1696,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_rprop,
optim_error_inputs_func=optim_error_inputs_func_rprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # Rprop doesn't update for non-contiguous, see #118117