Mirror of https://github.com/pytorch/pytorch.git
Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)
This PR addresses issue #123451. Specifically, the `test_graph_optims` and `test_graph_scaling_fused_optimizers` functions in `test_cuda.py` have been updated to use the new OptimizerInfo infrastructure.

Lintrunner passed:

```
$ lintrunner test/test_cuda.py
ok No lint issues.
```

Tests passed:

```
> python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s
OK (skipped=9)

> python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s
OK (skipped=3)
```

Both functions have been moved to the newly created `TestCudaOptims` TestCase class. Each test is mostly unchanged, except that the `@optims` decorator at the top of the function now implicitly runs it once per optimizer listed in the decorator, replacing the explicit for loop that previously iterated over the optimizers.

I was unable to use `_get_optim_inputs_including_global_cliquey_kwargs` to get all kwargs for each optimizer, because some of the kwargs used in the original `test_graph_optims` function are not returned by the new OptimizerInfo infrastructure. Specifically, for the `torch.optim.rmsprop.RMSprop` optimizer, the following kwargs are never returned when `_get_optim_inputs_including_global_cliquey_kwargs` is called:

```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{'foreach': True, 'maximize': True, 'weight_decay': 0}
```

I ran into the same issue in `test_graph_scaling_fused_optimizers`: for the `torch.optim.adamw.AdamW` optimizer, calling `optim_info.optim_inputs_func(device=device)` never returned the following kwarg:

```
{'amsgrad': True}
```

Because of this, I resorted to storing the kwargs for each optimizer in a dictionary, which I am aware is less than ideal. Should I instead use the OptimizerInfo infrastructure to get all the kwargs, despite it lacking some of them?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
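For context, the decorator-driven pattern described above looks roughly like the following. This is a minimal sketch, assuming PyTorch's internal test helpers (`optims`, `optim_db`, `instantiate_device_type_tests`, `run_tests`); the filter predicate and the test body are illustrative placeholders, not the exact code from this PR:

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCudaOptims(TestCase):
    # The @optims decorator re-runs the test once per OptimizerInfo in the
    # given list, passing it in as `optim_info`, instead of an explicit for
    # loop over optimizers inside the test body.
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],  # placeholder filter
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True)]
        optimizer = optim_cls(params)  # placeholder body; the real test exercises CUDA graphs
        params[0].sum().backward()
        optimizer.step()


instantiate_device_type_tests(TestCudaOptims, globals(), only_for="cuda")

if __name__ == "__main__":
    run_tests()
```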
Committed by: PyTorch MergeBot
Parent: 5fb11cda4f
Commit: cf35a591b9
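The dictionary-of-kwargs workaround mentioned in the description might look like the sketch below. The name and exact shape are hypothetical; the configs listed are the ones the PR description reports as missing from the OptimizerInfo machinery:

```python
import torch

# Hypothetical shape of the workaround: map each optimizer class to the
# kwarg configs that the OptimizerInfo machinery did not return.
missing_optimizer_kwargs = {
    torch.optim.RMSprop: [
        {"foreach": False, "maximize": True, "weight_decay": 0},
        {"foreach": True, "maximize": True, "weight_decay": 0},
    ],
    torch.optim.AdamW: [
        {"amsgrad": True},
    ],
}
```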
```diff
@@ -123,6 +123,8 @@ class OptimizerInfo:
         supported_impls: Tuple[str] = ("foreach", "differentiable"),
         # the optim supports passing in sparse gradients as well as dense grads
         supports_sparse: bool = False,
+        # the optimizer constructor supports passing in capturable as a kwarg
+        has_capturable_arg: bool = False,
         # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
         only_supports_sparse_grads: bool = False,
         # Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
@@ -147,6 +149,7 @@ class OptimizerInfo:
         self.scheduler_inputs = scheduler_inputs
         self.supported_impls = supported_impls
         self.supports_sparse = supports_sparse
+        self.has_capturable_arg = has_capturable_arg
         self.metadata_for_sparse = metadata_for_sparse
         self.only_supports_sparse_grads = only_supports_sparse_grads
         self.supports_complex = supports_complex
```
```diff
@@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
+            desc="maximize, weight_decay",
         ),
         OptimizerInput(
             params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
```
```diff
@@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None):
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "maximize": True},
+            kwargs={"maximize": True},
             desc="maximize",
         ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize, weight_decay",
+        ),
     ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
 
```
```diff
@@ -683,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None):
         ),
         OptimizerInput(
             params=None,
-            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            kwargs={
+                "weight_decay": 0.1,
+            },
             desc="weight_decay",
         ),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
+            desc="weight_decay, momentum_decay",
+        ),
         OptimizerInput(
             params=None,
             kwargs={
                 "weight_decay": 0.1,
                 "momentum_decay": 6e-3,
                 "decoupled_weight_decay": True,
             },
             desc="decoupled_weight_decay",
```
```diff
@@ -818,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+            },
+            desc="maximize",
+        ),
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "centered": True},
             desc="centered",
         ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "maximize": True,
+                "weight_decay": 0.1,
+            },
+            desc="maximize, weight_decay",
+        ),
         OptimizerInput(
             params=None,
             kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
@@ -836,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None):
                 "momentum": 0.1,
                 "maximize": True,
             },
-            desc="maximize",
+            desc="maximize, centered, weight_decay, w/ momentum",
         ),
     ] + (cuda_supported_configs if "cuda" in str(device) else [])
 
```
```diff
@@ -907,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None):
         OptimizerInput(
             params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
         ),
+        OptimizerInput(
+            params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
+        ),
         OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.1, "maximize": True},
+            desc="maximize",
+        ),
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "dampening": 0.5},
@@ -916,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None):
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "weight_decay": 0.1},
-            desc="non-zero weight_decay",
+            desc="weight_decay w/ momentum",
         ),
         OptimizerInput(
             params=None,
             kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
             desc="nesterov",
         ),
-        OptimizerInput(
-            params=None,
-            kwargs={"weight_decay": 0.1, "maximize": True},
-            desc="maximize",
-        ),
     ]
 
 
```
```diff
@@ -1097,6 +1130,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_adadelta,
         optim_error_inputs_func=optim_error_inputs_func_adadelta,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1232,6 +1266,7 @@ optim_db: List[OptimizerInfo] = [
         optim_error_inputs_func=optim_error_inputs_func_adam,
         supported_impls=("foreach", "differentiable", "fused"),
         supports_fused_on=("cpu", "cuda"),
+        has_capturable_arg=True,
         decorators=(
             # Expected floating point error between fused and compiled forloop
             DecorateInfo(
@@ -1298,6 +1333,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_adamax,
         optim_error_inputs_func=optim_error_inputs_func_adamax,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1348,6 +1384,7 @@ optim_db: List[OptimizerInfo] = [
         optim_error_inputs_func=optim_error_inputs_func_adamw,
         supported_impls=("foreach", "differentiable", "fused"),
         supports_fused_on=("cpu", "cuda"),
+        has_capturable_arg=True,
         decorators=(
             # Expected error between compiled forloop and fused optimizers
             DecorateInfo(
@@ -1414,6 +1451,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_asgd,
         optim_error_inputs_func=optim_error_inputs_func_asgd,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1506,6 +1544,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_nadam,
         optim_error_inputs_func=optim_error_inputs_func_nadam,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1561,6 +1600,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_radam,
         optim_error_inputs_func=optim_error_inputs_func_radam,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1606,6 +1646,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_rmsprop,
         optim_error_inputs_func=optim_error_inputs_func_rmsprop,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # addcdiv doesn't work for non-contiguous, see #118115
@@ -1655,6 +1696,7 @@ optim_db: List[OptimizerInfo] = [
         optim_inputs_func=optim_inputs_func_rprop,
         optim_error_inputs_func=optim_error_inputs_func_rprop,
         supported_impls=("foreach", "differentiable"),
+        has_capturable_arg=True,
         skips=(
             DecorateInfo(
                 skipIfMps,  # Rprop doesn't update for non-contiguous, see #118117
```