Add a tensor lr test for optimizers (#123139)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123139
Approved by: https://github.com/albanD
ghstack dependencies: #123134

Committed by: PyTorch MergeBot
Parent: cb8fc30e4a
Commit: f2838c99a0
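The feature exercised by the new test is passing the learning rate to an optimizer as a scalar Tensor instead of a Python float. A minimal sketch of that usage (not part of this commit; it assumes a PyTorch build whose optimizers accept a Tensor lr, as the AdamW test in the diff below does):

import torch
from torch.optim import AdamW

# Shapes mirror the ones used in the tests below.
weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))
inpt = torch.randn(5)

# lr may be a 0-dim Tensor rather than a float; test_tensor_lr below checks
# that the float and Tensor forms drive the parameters to the same values.
optimizer = AdamW([weight, bias], lr=torch.tensor(1e-3))

loss = (weight.mv(inpt) + bias).pow(2).sum()
loss.backward()
optimizer.step()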
@@ -1,7 +1,6 @@
# Owner(s): ["module: optimizer"]

import functools
import itertools

import torch
from torch.nn import Parameter
@@ -132,110 +131,6 @@ class TestOptim(TestCase):
            sum([rosenbrock(param_t) for param_t in params_t]),
        )

    def _test_basic_cases_template(
        self,
        weight_tensor,
        bias_tensor,
        input_tensor,
        constructor,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=False,
    ):
        maximize_options = {False, constructor_accepts_maximize}
        foreach_options = {False, constructor_accepts_foreach}

        four_arg_constructor = constructor
        if constructor_accepts_maximize and constructor_accepts_foreach:
            pass
        elif constructor_accepts_maximize:

            def four_arg_constructor(weight, bias, maximize, foreach):  # noqa: F811
                self.assertFalse(foreach)
                return constructor(weight, bias, maximize)

        elif constructor_accepts_foreach:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize)
                return constructor(weight, bias, foreach)

        else:

            def four_arg_constructor(weight, bias, maximize, foreach):
                self.assertFalse(maximize or foreach)
                return constructor(weight, bias)

        for maximize, foreach in itertools.product(maximize_options, foreach_options):
            with torch.no_grad():
                weight = Parameter(weight_tensor.clone().detach())
                bias = Parameter(bias_tensor.clone().detach())
                input = input_tensor.clone().detach().requires_grad_()
            optimizer = four_arg_constructor(weight, bias, maximize, foreach)

            def fn():
                optimizer.zero_grad()
                y = weight.mv(input)
                if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
                    y = y.cuda(bias.get_device())
                loss = (y + bias).pow(2).sum()
                loss.backward()
                return loss

            initial_value = fn().item()
            for _ in range(200):
                optimizer.step(fn)
            if maximize:
                self.assertGreater(fn().item(), initial_value)
            else:
                self.assertLess(fn().item(), initial_value)

    def _test_basic_cases(
        self,
        constructor,
        constructor_accepts_maximize=False,
        constructor_accepts_foreach=False,
    ):
        self._test_basic_cases_template(
            torch.randn(10, 5),
            torch.randn(10),
            torch.randn(5),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # non-contiguous parameters
        self._test_basic_cases_template(
            torch.randn(10, 5, 2)[..., 0],
            torch.randn(10, 2)[..., 0],
            torch.randn(5),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # CUDA
        if not torch.cuda.is_available():
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(),
            torch.randn(10).cuda(),
            torch.randn(5).cuda(),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )
        # Multi-GPU
        if not torch.cuda.device_count() > 1:
            return
        self._test_basic_cases_template(
            torch.randn(10, 5).cuda(0),
            torch.randn(10).cuda(1),
            torch.randn(5).cuda(0),
            constructor,
            constructor_accepts_maximize,
            constructor_accepts_foreach,
        )

    def test_sgd_sparse(self):
        for foreach in (False, True):
@@ -250,21 +145,6 @@ class TestOptim(TestCase):
        )

    def test_adamw(self):
        self._test_basic_cases(
            lambda weight, bias, maximize, foreach: AdamW(
                [weight, bias],
                lr=torch.tensor(1e-3),
                weight_decay=1,
                amsgrad=True,
                maximize=maximize,
                foreach=False,  # foreach for lr tensors tested in multi configs
            ),
            constructor_accepts_maximize=True,
            constructor_accepts_foreach=True,
        )

    def test_sparse_adam(self):
        self._test_rosenbrock_sparse(
            lambda params: SparseAdam(params, lr=4e-2), [], True
@@ -199,6 +199,56 @@ class TestOptimRenewed(TestCase):
                self.assertLess(closure().item(), initial_value)

    @optims(optim_db, dtypes=[torch.float32])
    def test_tensor_lr(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
        for optim_input in all_optim_inputs:
            weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
            weight_c = weight.clone().detach().requires_grad_(True)
            bias = Parameter(torch.randn((10), device=device, dtype=dtype))
            bias_c = bias.clone().detach().requires_grad_(True)
            inpt = torch.randn(5, device=device, dtype=dtype)

            kwargs = optim_input.kwargs
            if "lr" in kwargs:
                del kwargs["lr"]

            kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
            optimizer_r = optim_cls([weight, bias], **kwargs)

            try:
                kwargs["lr"] = torch.tensor(kwargs["lr"])
                optimizer = optim_cls([weight_c, bias_c], **kwargs)
            except ValueError as e:
                self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
                continue

            def closure(optim, w, b, i):
                optim.zero_grad()
                loss = (w.mv(i) + b).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # For this test, we naively convert the Tensor layout, which we know does
                    # NOT represent the expected use case for optims like SparseAdam!
                    w.grad = w.grad.to_sparse()
                    b.grad = b.grad.to_sparse()
                return loss

            for _ in range(5):
                if optim_info.step_requires_closure:
                    optimizer_r.step(functools.partial(closure, optimizer_r, weight, bias, inpt))
                    optimizer.step(functools.partial(closure, optimizer, weight_c, bias_c, inpt))
                else:
                    closure(optimizer_r, weight, bias, inpt)
                    closure(optimizer, weight_c, bias_c, inpt)

            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)

    @skipMPS
    @optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
    def test_complex(self, device, dtype, optim_info):
@@ -2,6 +2,7 @@

import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
@@ -1047,6 +1048,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -1150,6 +1157,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115607"
@@ -1250,6 +1263,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1327,6 +1346,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
                "TestOptimRenewed",
@@ -1441,6 +1466,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1512,6 +1543,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See discrepancy in https://github.com/pytorch/pytorch/issues/115607"
@@ -1619,6 +1656,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
        ),
    ),
    OptimizerInfo(
@@ -1647,6 +1690,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
@@ -1733,6 +1782,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -1828,6 +1883,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -1922,6 +1983,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/115679"
@@ -2031,6 +2098,12 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("Fails fix point assertion on 3.8, see #97811"),
                "TestOptimRenewed",
                "test_tensor_lr",
                active_if=sys.version_info < (3, 9) and sys.version_info > (3, 7),
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Dynamo memory usage is flaky, see https://github.com/pytorch/pytorch/issues/116046"
@@ -2148,6 +2221,11 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_tensor_lr",
            ),
            DecorateInfo(
                unittest.skip(
                    "SparseAdam does not support dense gradients, see #116507"