Adds a maximize flag to Adam (#68164)

Summary:
Solves the next most important use case in https://github.com/pytorch/pytorch/issues/68052.

I have kept the style as close to that in SGD as seemed reasonable, given the slight differences in their internal implementations.

All feedback welcome!
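For reference, here is a minimal usage sketch of the new flag; the toy parameter and objective below are illustrative only and not part of this PR:

```python
import torch

# Toy parameter and an objective we want to *maximize* (gradient ascent).
param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adam([param], lr=1e-1, maximize=True)  # new keyword-only flag

for _ in range(200):
    opt.zero_grad()
    objective = -(param - 2.0).pow(2).sum()  # maximized at param == 2
    objective.backward()
    opt.step()  # steps in the direction that increases the objective
```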

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68164

Reviewed By: VitalyFedyunin

Differential Revision: D32994129

Pulled By: albanD

fbshipit-source-id: 65c57c3f3dbbd3e3e5338d51def54482503e8850
commit 3d358a7678 (parent fc37e5b3ed)
Author: oliver
Date: 2021-12-13 05:52:07 -08:00
Committed by: Facebook GitHub Bot
6 changed files with 69 additions and 45 deletions

View File

@@ -689,10 +689,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
sharded_optimizer.load_state_dict(sharded_optim_state_dict)
check_step()
for opt in [torch.optim.Adam]:
check_optimizer_equivalence(opt, maximize=False)
for opt in [torch.optim.SGD]:
for opt in [torch.optim.Adam, torch.optim.SGD]:
for maximize in (True, False):
check_optimizer_equivalence(opt, maximize=maximize)

View File

@@ -480,58 +480,68 @@ class TestOptim(TestCase):
def test_adam(self):
for optimizer in [optim.Adam, optim_mt.Adam]:
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3)
lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, maximize=maximize),
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize),
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize),
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3)
lr=1e-3, amsgrad=True, maximize=maximize),
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True)
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1)
)
self._test_basic_cases(
lambda weight, bias: optimizer(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3, amsgrad=True)
lr=1e-3, maximize=maximize),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3),
[lambda opt: ExponentialLR(opt, gamma=0.9)]
lr=1e-3, maximize=maximize),
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3),
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)]
lr=1e-3, maximize=maximize),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)]
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9)]
lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt)]
lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True
)
self._test_basic_cases(
lambda weight, bias: optimizer(
lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3, amsgrad=True),
lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt)]
lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True
)
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
optimizer(None, lr=1e-2, betas=(1.0, 0.0))
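As a standalone illustration of the property these parametrized cases exercise (this is not the `_test_basic_cases` helper itself, just an assumed sketch of the behaviour being checked): with `maximize=True` the tracked objective should increase across optimizer steps.

```python
import torch
from torch import optim

# Hypothetical standalone check: the objective rises under maximize=True.
weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))
inp = torch.randn(5)
opt = optim.Adam([weight, bias], lr=1e-3, maximize=True)

def objective():
    return (weight.mv(inp) + bias).pow(2).sum()

initial = objective().item()
for _ in range(200):
    opt.zero_grad()
    objective().backward()
    opt.step()
assert objective().item() > initial  # objective went up, not down
```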

View File

@@ -23,6 +23,7 @@ class _FunctionalAdam(object):
eps: float = 1e-8,
weight_decay: float = 0.0,
amsgrad: bool = False,
maximize: bool = False,
_allow_empty_param_list: bool = False,
):
if not 0.0 <= lr:
@@ -44,6 +45,7 @@ class _FunctionalAdam(object):
"weight_decay": weight_decay,
}
self.amsgrad = amsgrad
self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
@@ -96,6 +98,7 @@ class _FunctionalAdam(object):
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
maximize=self.maximize,
beta1=self.defaults['beta1'],
beta2=self.defaults['beta2'],
lr=self.defaults['lr'],
@@ -156,6 +159,7 @@ class _FunctionalAdam(object):
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
maximize=self.maximize,
beta1=self.defaults['beta1'],
beta2=self.defaults['beta2'],
lr=self.defaults['lr'],

View File

@@ -73,7 +73,8 @@ def adam(params: List[Tensor],
beta2: float,
lr: float,
weight_decay: float,
eps: float):
eps: float,
maximize: bool):
r"""Functional API that performs Adam algorithm computation.
See :class:`~torch.optim.Adam` for details.
@@ -81,7 +82,7 @@ def adam(params: List[Tensor],
for i, param in enumerate(params):
grad = grads[i]
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
@@ -103,11 +104,11 @@ def adam(params: List[Tensor],
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
param.addcdiv_(exp_avg, denom, value=-step_size)
def adamw(params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],

View File

@@ -32,7 +32,7 @@ class Adam(Optimizer):
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
weight_decay=0, amsgrad=False, *, maximize: bool = False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -44,7 +44,7 @@ class Adam(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
@@ -75,6 +75,7 @@ class Adam(Optimizer):
max_exp_avg_sq = []
params_with_grad = []
for p in group['params']:
if p.grad is not None:
if p.grad.is_sparse:
@@ -82,6 +83,9 @@ class Adam(Optimizer):
params_with_grad.append(p)
grads.append(p.grad)
if group['maximize']:
grads = torch._foreach_neg(tuple(grads))
for p in params_with_grad:
state = self.state[p]
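In this multi-tensor variant the negation happens once for the whole gradient list, via a fused foreach call, rather than per parameter. A small runnable sketch of the lines above (the stand-in parameter group and dummy gradients are illustrative only):

```python
import torch

# Stand-in parameter group with dummy gradients, for illustration only.
group = {'params': [torch.nn.Parameter(torch.randn(3)) for _ in range(2)],
         'maximize': True}
for p in group['params']:
    p.grad = torch.ones_like(p)

params_with_grad, grads = [], []
for p in group['params']:
    if p.grad is not None:
        params_with_grad.append(p)
        grads.append(p.grad)

if group['maximize']:
    # Out-of-place fused negation: returns new tensors and leaves the
    # stored p.grad buffers untouched.
    grads = torch._foreach_neg(tuple(grads))
```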

View File

@@ -11,12 +11,17 @@ class Adam(Optimizer):
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
&\hspace{13mm} \lambda \text{ (weight decay)}, \: amsgrad \\
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
\:\textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
@@ -50,6 +55,8 @@ class Adam(Optimizer):
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
maximize (bool, optional): maximize the objective with respect to the params,
instead of minimizing (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -58,7 +65,7 @@ class Adam(Optimizer):
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
weight_decay=0, amsgrad=False, *, maximize: bool = False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -70,7 +77,7 @@ class Adam(Optimizer):
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
@@ -141,5 +148,6 @@ class Adam(Optimizer):
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'])
eps=group['eps'],
maximize=group['maximize'])
return loss
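As a final, purely illustrative sanity check (not part of this PR): since the flag only flips the sign of the gradient, maximizing an objective f should trace exactly the same parameter trajectory as minimizing -f with a plain Adam.

```python
import torch

p_max = torch.nn.Parameter(torch.ones(4))
p_min = torch.nn.Parameter(torch.ones(4))
opt_max = torch.optim.Adam([p_max], lr=1e-2, maximize=True)
opt_min = torch.optim.Adam([p_min], lr=1e-2)

for _ in range(10):
    opt_max.zero_grad()
    opt_min.zero_grad()
    p_max.sum().backward()      # objective to maximize
    (-p_min.sum()).backward()   # equivalent loss to minimize
    opt_max.step()
    opt_min.step()

assert torch.allclose(p_max, p_min)  # identical updates, only the sign differs
```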