Updated test_graph_optims and test_graph_scaling_fused_optimizers to use new OptimizerInfo infrastructure (#125127)

This PR addresses issue #123451. More specifically, the ```test_graph_optims``` and ```test_graph_scaling_fused_optimizers``` functions in ```test_cuda.py``` have been updated to use the new OptimizerInfo infrastructure.

Lintrunner passed:
```
$ lintrunner test/test_cuda.py
ok No lint issues.
```
Tests passed:
```
>python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s

OK (skipped=9)

>python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s

OK (skipped=3)
```
Both functions have been moved to the newly created ```TestCudaOptims``` TestCase class. The tests are largely unchanged, except that the ```@optims``` decorator at the top of each function now instantiates the test once for each optimizer listed in the decorator, rather than iterating over the optimizers explicitly with a for loop. A minimal sketch of the resulting pattern is shown below.
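
For reference, here is a minimal sketch of that pattern; the decorator arguments and the test signature mirror the diff below, while the ```TEST_CUDA_GRAPH``` skip condition is omitted and the test body is elided:
```python
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_utils import TestCase, run_tests


class TestCudaOptims(TestCase):
    # The decorator generates one test instance per optimizer in the list (and
    # per dtype), replacing the explicit `for optimizer_ctor, kwargs in cases:` loop.
    @onlyCUDA
    @optims(
        [optim for optim in optim_db if optim.has_capturable_arg],
        dtypes=[torch.float32],
    )
    def test_graph_optims(self, device, dtype, optim_info):
        ...  # body elided; see the diff below


instantiate_device_type_tests(TestCudaOptims, globals())

if __name__ == "__main__":
    run_tests()
```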

I was unable to use ```_get_optim_inputs_including_global_cliquey_kwargs``` to get all kwargs for each optimizer, because some of the kwargs used in the original ```test_graph_optims``` function are not returned by the new OptimizerInfo infrastructure. More specifically, for the ```torch.optim.rmsprop.RMSprop``` optimizer, the following kwargs are never returned when ```_get_optim_inputs_including_global_cliquey_kwargs``` is called:
```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{'foreach': True, 'maximize': True, 'weight_decay': 0}
```
I ran into the same issue with ```test_graph_scaling_fused_optimizers```: for the ```torch.optim.adamw.AdamW``` optimizer, whenever ```optim_info.optim_inputs_func(device=device)``` was called, the following kwarg was never returned (a short diagnostic reproducing both observations follows):
```
{'amsgrad': True}
```
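
Both observations can be reproduced with a diagnostic along the following lines. This is only a sketch: it relies on internal test helpers from a PyTorch source checkout, and the call signatures mirror the ones used in the diff below.
```python
import torch
from torch.testing._internal.common_optimizers import (
    _get_optim_inputs_including_global_cliquey_kwargs,
    optim_db,
)

# Kwarg combinations the new infrastructure generates for RMSprop on CUDA
# (requires a CUDA-enabled build for the "cuda" device string).
rmsprop_info = next(o for o in optim_db if o.optim_cls is torch.optim.RMSprop)
for optim_input in _get_optim_inputs_including_global_cliquey_kwargs(
    "cuda", torch.float32, rmsprop_info, skip=("differentiable",)
):
    print("RMSprop:", optim_input.kwargs)

# Kwarg combinations returned by optim_inputs_func for AdamW; per the report
# above, {'amsgrad': True} does not appear on its own in this list.
adamw_info = next(o for o in optim_db if o.optim_cls is torch.optim.AdamW)
for optim_input in adamw_info.optim_inputs_func(device="cuda"):
    print("AdamW:", optim_input.kwargs)
```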

Due to this issue, I resorted to using a dictionary to store the kwargs for each of the optimizers; I am aware that this is less than ideal. I was wondering whether I should instead use the OptimizerInfo infrastructure to get all the kwargs, even though it currently lacks some of them. A rough illustration of such a mapping is sketched below.
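
For illustration only, such a per-optimizer mapping could look roughly like this, using a couple of the kwarg combinations from the old hard-coded tests; the dictionary name is hypothetical and this is not what the final diff below relies on.
```python
import torch

# Hypothetical per-optimizer kwargs table, assembled from combinations the old
# hard-coded tests covered; purely illustrative.
OPTIM_KWARGS = {
    torch.optim.RMSprop: [
        {"lr": 0.1, "foreach": False, "maximize": True, "weight_decay": 0},
        {"lr": 0.1, "foreach": True, "maximize": True, "weight_decay": 0},
    ],
    torch.optim.AdamW: [
        {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": True},
    ],
}

for optim_cls, kwarg_list in OPTIM_KWARGS.items():
    for kwargs in kwarg_list:
        print(optim_cls.__name__, kwargs)
```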

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
jayanth domalapalli
2024-05-20 06:20:45 +00:00
committed by PyTorch MergeBot
parent 5fb11cda4f
commit cf35a591b9
3 changed files with 225 additions and 234 deletions

@@ -136,6 +136,8 @@ KERNEL_COUNT_OVERRIDES = {
"test_sgd_momentum_foreach_cuda": 5,
"test_sgd_weight_decay_maximize_cuda": 4,
"test_sgd_weight_decay_maximize_cpu": 4,
"test_sgd_weight_decay_cpu": 4,
"test_sgd_weight_decay_cuda": 4,
"test_sgd_momentum_weight_decay_foreach_cuda": 2,
"test_sgd_momentum_nesterov_weight_decay_foreach_cuda": 2,
"test_sgd_cuda": 4,

@@ -37,7 +37,11 @@ from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
)
from torch.testing._internal.common_optimizers import optim_db, optims
from torch.testing._internal.common_optimizers import (
_get_optim_inputs_including_global_cliquey_kwargs,
optim_db,
optims,
)
from torch.testing._internal.common_utils import (
freeze_rng_state,
gcIfJetson,
@@ -3200,111 +3204,6 @@ exit(2)
for p_control, p_graphed in zip(params_control, params_graphed):
self.assertEqual(p_control, p_graphed)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_graph_optims(self):
# Needs generalization if we want to extend this test to non-Adam-like optimizers.
cases = (
[
(
optimizer_ctor,
{
"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
"decoupled_weight_decay": decoupled_weight_decay,
"weight_decay": weight_decay,
},
)
for optimizer_ctor, foreach, decoupled_weight_decay, weight_decay in product(
(
torch.optim.NAdam,
torch.optim.RAdam,
),
(
False,
True,
),
(
False,
True,
),
(
0.0,
0.1,
),
)
]
+ [
(
torch.optim.Rprop,
{"lr": 0.1, "foreach": foreach, "maximize": maximize},
)
for foreach, maximize in product(
(
False,
True,
),
(
False,
True,
),
)
]
+ [
(
optimizer_ctor,
{
"lr": 0.1,
"betas": (0.8, 0.7),
"foreach": foreach,
"amsgrad": amsgrad,
},
)
for optimizer_ctor, foreach, amsgrad in product(
(torch.optim.Adam, torch.optim.AdamW),
(False, True),
(False, True),
)
]
+ [
(
optimizer_ctor,
{"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
)
for optimizer_ctor, amsgrad in product(
(torch.optim.Adam, torch.optim.AdamW), (False, True)
)
]
+ [
(
optimizer_ctor,
{
"lr": 0.1,
"foreach": foreach,
"maximize": maximize,
"weight_decay": weight_decay,
},
)
for optimizer_ctor, foreach, maximize, weight_decay in product(
(
torch.optim.Adamax,
torch.optim.ASGD,
torch.optim.Adadelta,
torch.optim.RMSprop,
),
(False, True),
(False, True),
(0, 0.1),
)
]
)
for optimizer_ctor, kwargs in cases:
with self.subTest(optimizer_ctor=optimizer_ctor, kwargs=kwargs):
self._test_graphed_optimizer(3, 2, optimizer_ctor, kwargs)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
@@ -3376,123 +3275,6 @@ exit(2)
self.assertEqual(ref_p1, param1)
self.assertEqual(ref_p2, param2)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
def test_graph_scaling_fused_optimizers(self):
cases = [
(
optimizer_ctor,
{"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad},
)
for optimizer_ctor, amsgrad in product(
(torch.optim.Adam, torch.optim.AdamW), (False, True)
)
] + list(
product(
(torch.optim.SGD,),
[
{
"lr": 0.1,
"momentum": 0.0,
"dampening": d,
"weight_decay": w,
"nesterov": n,
"fused": True,
}
for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
]
+ [
{
"lr": 0.1,
"momentum": 0.5,
"dampening": d,
"weight_decay": w,
"nesterov": n,
"fused": True,
}
for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
],
)
)
steps_warmup = 3
steps_train = 2
for OptClass, kwargs in cases:
has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW)
for actually_do_graphs in (True, False) if has_capturable_arg else (True,):
params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
params_control = [p.clone().requires_grad_() for p in params]
params_graphed = [p.clone().requires_grad_() for p in params]
# `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
grads = [
[torch.randn_like(p) for p in params]
for _ in range(steps_warmup + steps_train)
]
with torch.no_grad():
grads_control = [[g.clone() for g in gs] for gs in grads]
grads_graphed = [[g.clone() for g in gs] for gs in grads]
# Gradient Scaler
scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
with torch.no_grad():
scaler_for_control._lazy_init_scale_growth_tracker(
torch.device("cuda")
)
scaler_for_graphed = torch.cuda.amp.GradScaler()
scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
with torch.no_grad():
scaler_for_graphed._lazy_init_scale_growth_tracker(
torch.device("cuda")
)
# Control (capturable=False)
if has_capturable_arg:
kwargs["capturable"] = False
opt = OptClass(params_control, **kwargs)
for i in range(steps_warmup + steps_train):
for j, p in enumerate(params_control):
p.grad = grads_control[i][j]
scaler_for_control.step(opt)
scaler_for_control.update()
# capturable=True
if has_capturable_arg:
kwargs["capturable"] = True
opt = OptClass(params_graphed, **kwargs)
for i in range(steps_warmup):
for j, p in enumerate(params_graphed):
p.grad = grads_graphed[i][j]
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
if actually_do_graphs:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
for i in range(steps_train):
if actually_do_graphs:
for j, p in enumerate(params_graphed):
p.grad.copy_(grads_graphed[i + steps_warmup][j])
g.replay()
else:
# Passing capturable=True to the constructor and running without graphs should still be
# numerically correct, even if it's not ideal for performance.
for j, p in enumerate(params_graphed):
p.grad = grads_graphed[i + steps_warmup][j]
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
for p_control, p_graphed in zip(params_control, params_graphed):
self.assertEqual(p_control, p_graphed)
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
@@ -4698,10 +4480,175 @@ class TestBlockStateAbsorption(TestCase):
self.assertEqual(rc, "False", "Triton was imported when importing torch!")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCudaOptims(TestCase):
# These tests will be instantiated with instantiate_device_type_tests
# to apply the new OptimizerInfo structure.
@onlyCUDA
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >=5.3 required for graphs"
)
@optims(
[optim for optim in optim_db if optim.has_capturable_arg],
dtypes=[torch.float32],
)
def test_graph_optims(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
)
steps_warmup = 3
steps_train = 2
for optim_input in all_optim_inputs:
kwargs = optim_input.kwargs
# lr as a Tensor is not supported when capturable=False and foreach=True for torch.optim.adam
# and torch.optim.adamw
kwargs["lr"] = 0.1
for actually_do_graphs in (True, False):
params = [
torch.randn((i + 5, i + 5), device=device) for i in range(2)
] + [torch.randn((), device=device)]
params_control = [p.clone().requires_grad_() for p in params]
params_graphed = [p.clone().requires_grad_() for p in params]
grads = [
[torch.randn_like(p) for p in params]
for _ in range(steps_warmup + steps_train)
]
# Control (capturable=False)
kwargs["capturable"] = False
opt = optim_cls(params_control, **kwargs)
for i in range(steps_warmup + steps_train):
for j, p in enumerate(params_control):
p.grad = grads[i][j]
opt.step()
# capturable=True
kwargs["capturable"] = True
opt = optim_cls(params_graphed, **kwargs)
for i in range(steps_warmup):
for j, p in enumerate(params_graphed):
p.grad = grads[i][j]
opt.step()
if actually_do_graphs:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
opt.step()
for i in range(steps_train):
if actually_do_graphs:
for j, p in enumerate(params_graphed):
p.grad.copy_(grads[i + steps_warmup][j])
g.replay()
else:
# Passing capturable=True to the constructor and running without graphs should still be
# numerically correct, even if it's not ideal for performance.
for j, p in enumerate(params_graphed):
p.grad = grads[i + steps_warmup][j]
opt.step()
for p_control, p_graphed in zip(params_control, params_graphed):
self.assertEqual(p_control, p_graphed)
@onlyCUDA
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
@optims(
[optim for optim in optim_db if "fused" in optim.supported_impls],
dtypes=[torch.float32],
)
def test_graph_scaling_fused_optimizers(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls
steps_warmup = 3
steps_train = 2
optim_inputs = optim_info.optim_inputs_func(device=device)
for optim_input in optim_inputs:
kwargs = optim_input.kwargs
kwargs["fused"] = True
for actually_do_graphs in (
(True, False) if optim_info.has_capturable_arg else (True,)
):
params = [torch.randn((i + 5, i + 5), device=device) for i in range(2)]
params_control = [p.clone().requires_grad_() for p in params]
params_graphed = [p.clone().requires_grad_() for p in params]
# `GradScaler` in-place updates gradients thus it's necessary to duplicate gradients.
grads = [
[torch.randn_like(p) for p in params]
for _ in range(steps_warmup + steps_train)
]
with torch.no_grad():
grads_control = [[g.clone() for g in gs] for gs in grads]
grads_graphed = [[g.clone() for g in gs] for gs in grads]
# Gradient Scaler
scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
with torch.no_grad():
scaler_for_control._lazy_init_scale_growth_tracker(device)
scaler_for_graphed = torch.cuda.amp.GradScaler()
scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
with torch.no_grad():
scaler_for_graphed._lazy_init_scale_growth_tracker(device)
# Control (capturable=False)
if optim_info.has_capturable_arg:
kwargs["capturable"] = False
opt = optim_cls(params_control, **kwargs)
for i in range(steps_warmup + steps_train):
for j, p in enumerate(params_control):
p.grad = grads_control[i][j]
scaler_for_control.step(opt)
scaler_for_control.update()
# capturable=True
if optim_info.has_capturable_arg:
kwargs["capturable"] = True
opt = optim_cls(params_graphed, **kwargs)
for i in range(steps_warmup):
for j, p in enumerate(params_graphed):
p.grad = grads_graphed[i][j]
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
if actually_do_graphs:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
for i in range(steps_train):
if actually_do_graphs:
for j, p in enumerate(params_graphed):
p.grad.copy_(grads_graphed[i + steps_warmup][j])
g.replay()
else:
# Passing capturable=True to the constructor and running without graphs should still be
# numerically correct, even if it's not ideal for performance.
for j, p in enumerate(params_graphed):
p.grad = grads_graphed[i + steps_warmup][j]
scaler_for_graphed.step(opt)
scaler_for_graphed.update()
for p_control, p_graphed in zip(params_control, params_graphed):
self.assertEqual(p_control, p_graphed)
@onlyCUDA
@unittest.skipIf(
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"

@@ -123,6 +123,8 @@ class OptimizerInfo:
supported_impls: Tuple[str] = ("foreach", "differentiable"),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
# the optimizer constructor supports passing in capturable as a kwarg
has_capturable_arg: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
only_supports_sparse_grads: bool = False,
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
@@ -147,6 +149,7 @@ class OptimizerInfo:
self.scheduler_inputs = scheduler_inputs
self.supported_impls = supported_impls
self.supports_sparse = supports_sparse
self.has_capturable_arg = has_capturable_arg
self.metadata_for_sparse = metadata_for_sparse
self.only_supports_sparse_grads = only_supports_sparse_grads
self.supports_complex = supports_complex
@@ -311,10 +314,11 @@ def optim_inputs_func_adadelta(device, dtype=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
desc="maximize, weight_decay",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
@@ -528,9 +532,14 @@ def optim_inputs_func_adamax(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
kwargs={"maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])
@@ -683,14 +692,20 @@ def optim_inputs_func_nadam(device, dtype=None):
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
desc="weight_decay, momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
},
desc="decoupled_weight_decay",
@@ -818,11 +833,26 @@ def optim_inputs_func_rmsprop(device, dtype=None):
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
"weight_decay": 0.1,
},
desc="maximize, weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
@@ -836,7 +866,7 @@ def optim_inputs_func_rmsprop(device, dtype=None):
"momentum": 0.1,
"maximize": True,
},
desc="maximize",
desc="maximize, centered, weight_decay, w/ momentum",
),
] + (cuda_supported_configs if "cuda" in str(device) else [])
@@ -907,7 +937,15 @@ def optim_inputs_func_sgd(device, dtype=None):
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "dampening": 0.5},
@@ -916,18 +954,13 @@ def optim_inputs_func_sgd(device, dtype=None):
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "weight_decay": 0.1},
desc="non-zero weight_decay",
desc="weight_decay w/ momentum",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
]
@@ -1097,6 +1130,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adadelta,
optim_error_inputs_func=optim_error_inputs_func_adadelta,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1232,6 +1266,7 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
has_capturable_arg=True,
decorators=(
# Expected floating point error between fused and compiled forloop
DecorateInfo(
@@ -1298,6 +1333,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adamax,
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1348,6 +1384,7 @@ optim_db: List[OptimizerInfo] = [
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
supports_fused_on=("cpu", "cuda"),
has_capturable_arg=True,
decorators=(
# Expected error between compiled forloop and fused optimizers
DecorateInfo(
@@ -1414,6 +1451,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_asgd,
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1506,6 +1544,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_nadam,
optim_error_inputs_func=optim_error_inputs_func_nadam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1561,6 +1600,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_radam,
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
@@ -1606,6 +1646,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_rmsprop,
optim_error_inputs_func=optim_error_inputs_func_rmsprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # addcdiv doesn't work for non-contiguous, see #118115
@@ -1655,6 +1696,7 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_rprop,
optim_error_inputs_func=optim_error_inputs_func_rprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfMps, # Rprop doesn't update for non-contiguous, see #118117