Migrate param_group testing to OptimizerInfo (#117675)

Today, our param_group testing amounts to configuring weight and bias with different optimizer hyperparams and then checking that the overall result moves in the right direction based on maximize.

This PR introduces two tests to provide that coverage:
1. For every optimizer input (skipping differentiable), always force bias to have 0 weight_decay, and then check that the loss moves in the expected direction. This is basically a replica of today's tests, but more methodical, as the test reflects a real use case. (A minimal sketch of this param-group pattern follows the list.)
2. To ensure that the different groups actually have distinct behavior, I added another test where lr is basically 0 in the default group, and check that the param in the default group doesn't move while the loss does.
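For readers who want the shape of test 1 without scrolling to the diff below, here is a minimal sketch of the per-group weight_decay setup it exercises; the optimizer choice (SGD), the shapes, and the hyperparameter values are illustrative, not the exact ones the test uses.

```python
import torch
from torch.optim import SGD

# Stand-ins for the weight and bias Parameters the test builds.
weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))
inp = torch.randn(5)

# Two param groups: weight inherits the default weight_decay, bias overrides it to 0.
optimizer = SGD(
    [
        {"params": [weight]},
        {"params": [bias], "weight_decay": 0.0},
    ],
    lr=1e-3,
    weight_decay=0.1,
)

initial_loss = (weight.mv(inp) + bias).pow(2).sum().item()
for _ in range(20):
    optimizer.zero_grad()
    loss = (weight.mv(inp) + bias).pow(2).sum()
    loss.backward()
    optimizer.step()

# With maximize=False (the default), the loss should have gone down.
assert loss.item() < initial_loss
```

The actual test runs this pattern for every optimizer class and every OptimizerInfo-provided input (minus differentiable), flipping the assertion when maximize=True.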

Together, these tests do a better job of testing param groups than today's tests, **though we do lose some flavors**. For example, RMSprop also pits centered=True vs. False across the param_groups, Adadelta has a variation on rho, and ASGD has a variation on t0. I don't think this is really a loss: the previous tests only checked direction, while the new tests check stronger guarantees.

The leftover param group configs are used in conjunction with LRSchedulers.
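As a rough sketch of how those configs come into play (illustrative values, with StepLR chosen here as a stand-in scheduler), an LR scheduler scales every param group's lr independently, so per-group lr overrides remain meaningful throughout training:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

w1 = torch.nn.Parameter(torch.randn(3))
w2 = torch.nn.Parameter(torch.randn(3))

# Second group overrides the default lr of 1e-3 with 1e-2.
optimizer = SGD([{"params": [w1]}, {"params": [w2], "lr": 1e-2}], lr=1e-3)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

for _ in range(2):
    optimizer.zero_grad()
    (w1.sum() + w2.sum()).backward()
    optimizer.step()
    scheduler.step()

# Each group's lr was decayed independently: roughly 1e-3 -> 1e-5 and 1e-2 -> 1e-4.
print([group["lr"] for group in optimizer.param_groups])
```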

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117675
Approved by: https://github.com/albanD
Author: Jane Xu
Date: 2024-01-22 20:29:34 +00:00
Committed by: PyTorch MergeBot
Parent: d280b6ae58
Commit: c6be5d55a5

3 changed files with 115 additions and 194 deletions


@@ -325,9 +325,6 @@ class TestOptim(TestCase):
def _build_params_dict(self, weight, bias, **kwargs):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
def _build_params_dict_single(self, weight, bias, **kwargs):
return [dict(params=bias, **kwargs)]
def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
@@ -336,35 +333,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
self._build_params_dict_single(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
self._build_params_dict_single(weight, bias, lr=1e-2),
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
@@ -530,16 +498,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
@@ -562,17 +520,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
@@ -694,16 +641,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
@@ -774,15 +711,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),
@@ -815,12 +743,6 @@ class TestOptim(TestCase):
)
def test_nadam(self):
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias], lr=1e-3, foreach=foreach
@@ -915,16 +837,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
@@ -989,16 +901,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adamax(
[weight, bias],
@@ -1021,12 +923,6 @@ class TestOptim(TestCase):
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach
),
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach
@@ -1094,62 +990,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
self._build_params_dict(weight, bias, lr=1e-3),
lr=1e-2,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
self._build_params_dict(weight, bias, lr=1e-3),
lr=1e-2,
centered=True,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
self._build_params_dict(weight, bias, lr=1e-3),
lr=1e-2,
centered=True,
momentum=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
self._build_params_dict(weight, bias, lr=1e-3),
lr=1e-2,
momentum=0.1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: RMSprop(
self._build_params_dict(weight, bias, lr=1e-3),
lr=1e-2,
momentum=0.1,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
self._test_complex_2d(
lambda param: RMSprop(param, centered=True, foreach=foreach)
@@ -1183,28 +1023,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: ASGD(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
t0=100,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: ASGD(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
# Ref: https://github.com/pytorch/pytorch/issues/84560
# self._test_complex_2d(optimizer)
self._test_complex_optimizer(
@@ -1239,18 +1057,6 @@ class TestOptim(TestCase):
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Rprop(
self._build_params_dict(weight, bias, lr=1e-2),
lr=2e-4,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
atol=4e-5 if is_cuda_sm86 else None,
rtol=3e-5 if is_cuda_sm86 else None,
)
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
self._test_complex_optimizer(
lambda param: Rprop([param], lr=0.001, foreach=foreach)


@@ -359,6 +359,90 @@ class TestOptimRenewed(TestCase):
            optimizer_cuda.step()

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_groups_weight_decay(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
        for optim_input in all_optim_inputs:
            weight_kwargs = optim_input.kwargs
            bias_kwargs = deepcopy(optim_input.kwargs)
            bias_kwargs["weight_decay"] = 0.0

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            input = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls([dict(params=[weight], **weight_kwargs), dict(params=[bias], **bias_kwargs)])

            loss = (weight.mv(input) + bias).pow(2).sum()
            initial_value = loss.item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = (weight.mv(input) + bias).pow(2).sum()
                loss.backward()
                if optim_cls.__name__ == "SparseAdam":
                    # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
                    # which we know does NOT represent the expected use case!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(loss.item(), initial_value)
            else:
                self.assertLess(loss.item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_param_groups_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
        for optim_input in all_optim_inputs:
            # optim_input.kwargs will be the param group kwargs, which should have >0 lr
            if "lr" not in optim_input.kwargs or optim_input.kwargs["lr"] == 0:
                optim_input.kwargs["lr"] = 1e-3
            outer_kwargs = {"lr": 1e-28}
            if optim_cls.__name__ == "Rprop":
                # Allow min step size to be 0
                outer_kwargs["step_sizes"] = (0, 50)

            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
            irrelevant_clone = irrelevant.clone()
            input = torch.randn(5, device=device, dtype=dtype)

            optimizer = optim_cls(
                [dict(params=[weight, bias], **optim_input.kwargs), dict(params=[irrelevant])],
                **outer_kwargs)

            loss = (weight.mv(input) + bias).pow(2).sum()
            initial_value = loss.item()
            for _ in range(20):
                optimizer.zero_grad()
                loss = (weight.mv(input) + bias).pow(2).sum()
                loss.backward()
                irrelevant.grad = torch.rand_like(irrelevant)
                if optim_cls.__name__ == "SparseAdam":
                    # SparseAdam requires sparse gradients. For this test, we convert the Tensor layout,
                    # which we know does NOT represent the expected use case!
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                    irrelevant.grad = irrelevant.grad.to_sparse()
                optimizer.step()

            # Test that the direction of loss moved appropriately
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(loss.item(), initial_value)
            else:
                self.assertLess(loss.item(), initial_value)

            # Test that irrelevant parameters were not updated since lr was almost 0
            self.assertEqual(irrelevant, irrelevant_clone)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls


@@ -1191,6 +1191,16 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_groups_weight_decay",
            ),
        ),
    ),
    OptimizerInfo(
@@ -1436,6 +1446,22 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061"
                ),
                "TestOptimRenewed",
                "test_param_groups_weight_decay",
                device_type="cpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061"
                ),
                "TestOptimRenewed",
                "test_param_groups_lr",
                device_type="cpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors with list out of range, see https://github.com/pytorch/pytorch/issues/116061"
@@ -1478,6 +1504,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_state_dict_deterministic",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                unittest.skip(
                    "SparseAdam does not support dense gradients, see #116507"