Adds a maximize flag to Adam (#68164)
Summary: Solves the next most important use case in https://github.com/pytorch/pytorch/issues/68052. I have kept the style as close to that in SGD as seemed reasonable, given the slight differences in their internal implementations. All feedback welcome!

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68164

Reviewed By: VitalyFedyunin

Differential Revision: D32994129

Pulled By: albanD

fbshipit-source-id: 65c57c3f3dbbd3e3e5338d51def54482503e8850
Committed by: Facebook GitHub Bot
Parent: fc37e5b3ed
Commit: 3d358a7678
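With this change the flag goes straight into the constructor. A minimal usage sketch (the toy objective below is illustrative, not part of the PR): with maximize=True, Adam performs gradient ascent on the objective instead of descent.

    import torch

    # A single parameter and a concave objective whose maximum sits at 3.0.
    param = torch.nn.Parameter(torch.tensor(0.0))
    opt = torch.optim.Adam([param], lr=0.1, maximize=True)

    for _ in range(200):
        opt.zero_grad()
        objective = -(param - 3.0) ** 2
        objective.backward()
        opt.step()  # steps uphill on the objective rather than downhill

    print(param.item())  # expected to end up close to 3.0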
@@ -689,10 +689,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
                 sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                 check_step()
 
-        for opt in [torch.optim.Adam]:
-            check_optimizer_equivalence(opt, maximize=False)
-
-        for opt in [torch.optim.SGD]:
+        for opt in [torch.optim.Adam, torch.optim.SGD]:
             for maximize in (True, False):
                 check_optimizer_equivalence(opt, maximize=maximize)
 
@@ -480,58 +480,68 @@ class TestOptim(TestCase):
     def test_adam(self):
         for optimizer in [optim.Adam, optim_mt.Adam]:
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
-                    self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3)
+                lambda weight, bias, maximize: optimizer(
+                    self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True)
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1)
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize),
+                constructor_accepts_maximize=True
            )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3, amsgrad=True)
+                    lr=1e-3, amsgrad=True, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3),
-                [lambda opt: ExponentialLR(opt, gamma=0.9)]
+                    lr=1e-3, maximize=maximize),
+                [lambda opt: ExponentialLR(opt, gamma=0.9)],
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3),
-                [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)]
+                    lr=1e-3, maximize=maximize),
+                [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3),
-                [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)]
+                    lr=1e-3, maximize=maximize),
+                [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
                 [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
-                 lambda opt: ExponentialLR(opt, gamma=0.9)]
+                 lambda opt: ExponentialLR(opt, gamma=0.9)],
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
                 [lambda opt: ExponentialLR(opt, gamma=0.9),
-                 lambda opt: ReduceLROnPlateau(opt)]
+                 lambda opt: ReduceLROnPlateau(opt)],
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3, amsgrad=True),
+                    lr=1e-3, amsgrad=True, maximize=maximize),
                 [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
-                 lambda opt: ReduceLROnPlateau(opt)]
+                 lambda opt: ReduceLROnPlateau(opt)],
+                constructor_accepts_maximize=True
             )
             with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
                 optimizer(None, lr=1e-2, betas=(1.0, 0.0))
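The property the updated harness exercises can also be checked directly: maximizing f with Adam should trace the same trajectory as minimizing -f with an otherwise identical optimizer. A standalone sketch of that check (not the actual _test_basic_cases helper):

    import torch

    torch.manual_seed(0)
    w_max = torch.nn.Parameter(torch.randn(5))
    w_min = torch.nn.Parameter(w_max.detach().clone())

    opt_max = torch.optim.Adam([w_max], lr=1e-2, maximize=True)
    opt_min = torch.optim.Adam([w_min], lr=1e-2)  # default maximize=False

    def f(w):
        return (w * torch.arange(1.0, 6.0)).sum()

    for _ in range(10):
        opt_max.zero_grad()
        f(w_max).backward()
        opt_max.step()            # ascend f

        opt_min.zero_grad()
        (-f(w_min)).backward()
        opt_min.step()            # descend -f

    print(torch.allclose(w_max, w_min))  # expected: True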
@@ -23,6 +23,7 @@ class _FunctionalAdam(object):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         amsgrad: bool = False,
+        maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
         if not 0.0 <= lr:
@@ -44,6 +45,7 @@ class _FunctionalAdam(object):
             "weight_decay": weight_decay,
         }
         self.amsgrad = amsgrad
+        self.maximize = maximize
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
 
         if len(params) == 0 and not _allow_empty_param_list:
@@ -96,6 +98,7 @@ class _FunctionalAdam(object):
                    max_exp_avg_sqs,
                    state_steps,
                    amsgrad=self.amsgrad,
+                   maximize=self.maximize,
                    beta1=self.defaults['beta1'],
                    beta2=self.defaults['beta2'],
                    lr=self.defaults['lr'],
@@ -156,6 +159,7 @@ class _FunctionalAdam(object):
                    max_exp_avg_sqs,
                    state_steps,
                    amsgrad=self.amsgrad,
+                   maximize=self.maximize,
                    beta1=self.defaults['beta1'],
                    beta2=self.defaults['beta2'],
                    lr=self.defaults['lr'],
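_FunctionalAdam is the TorchScript-friendly optimizer that torch.distributed drives by passing gradients to step() explicitly instead of reading them from .grad. A rough eager-mode sketch of the new flag there, assuming the private class and its step(gradients) signature stay as in this revision:

    import torch
    from torch.distributed.optim.functional_adam import _FunctionalAdam  # private module, may change

    param = torch.zeros(3)
    opt = _FunctionalAdam([param], lr=0.1, maximize=True)

    # A positive "gradient"; with maximize=True the update should push the
    # parameter values up rather than down.
    opt.step([torch.ones(3)])
    print(param)  # expected: entries slightly above zero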
@@ -73,7 +73,8 @@ def adam(params: List[Tensor],
          beta2: float,
          lr: float,
          weight_decay: float,
-         eps: float):
+         eps: float,
+         maximize: bool):
     r"""Functional API that performs Adam algorithm computation.
 
     See :class:`~torch.optim.Adam` for details.
@@ -81,7 +82,7 @@ def adam(params: List[Tensor],
 
     for i, param in enumerate(params):
 
-        grad = grads[i]
+        grad = grads[i] if not maximize else -grads[i]
         exp_avg = exp_avgs[i]
         exp_avg_sq = exp_avg_sqs[i]
         step = state_steps[i]
@@ -103,11 +104,11 @@ def adam(params: List[Tensor],
         else:
             denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
 
         step_size = lr / bias_correction1
 
         param.addcdiv_(exp_avg, denom, value=-step_size)
 
 
 def adamw(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
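The whole feature rests on the single negation above: with the gradient sign flipped, the usual update moves the parameter uphill instead of downhill. A hand-worked single Adam step (plain Python, mirroring the bias corrections and denom computation used here) makes the sign flip visible:

    import math

    def one_adam_step(param, grad, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, maximize=False):
        if maximize:
            grad = -grad                       # same trick as `grads[i] if not maximize else -grads[i]`
        exp_avg = (1 - beta1) * grad           # m_1, starting from m_0 = 0
        exp_avg_sq = (1 - beta2) * grad ** 2   # v_1, starting from v_0 = 0
        bias_correction1 = 1 - beta1           # first step, so beta1 ** 1
        bias_correction2 = 1 - beta2
        denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
        step_size = lr / bias_correction1
        return param - step_size * exp_avg / denom

    print(one_adam_step(0.0, 2.0))                  # ~ -0.1: descends
    print(one_adam_step(0.0, 2.0, maximize=True))   # ~ +0.1: ascends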
@@ -32,7 +32,7 @@ class Adam(Optimizer):
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, amsgrad=False):
+                 weight_decay=0, amsgrad=False, *, maximize: bool = False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -44,7 +44,7 @@ class Adam(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad)
+                        weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
         super(Adam, self).__init__(params, defaults)
 
     def __setstate__(self, state):
@@ -75,6 +75,7 @@ class Adam(Optimizer):
             max_exp_avg_sq = []
             params_with_grad = []
 
 
             for p in group['params']:
                 if p.grad is not None:
                     if p.grad.is_sparse:
@@ -82,6 +83,9 @@ class Adam(Optimizer):
                         params_with_grad.append(p)
                         grads.append(p.grad)
 
+            if group['maximize']:
+                grads = torch._foreach_neg(tuple(grads))
+
             for p in params_with_grad:
                 state = self.state[p]
 
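The multi-tensor variant negates the whole gradient list in one call instead of looping over it. The torch._foreach_* ops are private helpers that apply an elementwise op across a list of tensors (typically dispatched as a single fused kernel on CUDA); a quick sketch of what the added branch computes:

    import torch

    grads = [torch.tensor([1.0, -2.0]), torch.tensor([3.0])]

    # Equivalent in result to [g.neg() for g in grads]; private API, subject to change.
    negated = torch._foreach_neg(tuple(grads))
    print(negated)  # tensors [-1., 2.] and [-3.]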
@@ -11,12 +11,17 @@ class Adam(Optimizer):
             &\rule{110mm}{0.4pt}                                                                 \\
             &\textbf{input}      : \gamma \text{ (lr)}, \beta_1, \beta_2
                 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}          \\
-            &\hspace{13mm}      \lambda \text{ (weight decay)}, \: amsgrad                       \\
+            &\hspace{13mm}      \lambda \text{ (weight decay)}, \: \textit{amsgrad},
+                \:\textit{maximize}                                                              \\
             &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                 v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
             &\rule{110mm}{0.4pt}                                                                 \\
             &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
-            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
+
+            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
+            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
+            &\hspace{5mm}\textbf{else}                                                           \\
+            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
             &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
             &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
             &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
@@ -50,6 +55,8 @@ class Adam(Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False)
+        maximize (bool, optional): maximize the params based on the objective, instead of
+            minimizing (default: False)
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
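Stated compactly, the only change to the algorithm in the docstring above is the sign of the gradient that feeds the moment estimates:

    g_t =
    \begin{cases}
      -\nabla_{\theta} f_t(\theta_{t-1}) & \text{if } \textit{maximize} \\
      \phantom{-}\nabla_{\theta} f_t(\theta_{t-1}) & \text{otherwise,}
    \end{cases}

after which the m_t and v_t updates and the parameter step proceed exactly as before.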
@@ -58,7 +65,7 @@ class Adam(Optimizer):
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, amsgrad=False):
+                 weight_decay=0, amsgrad=False, *, maximize: bool = False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -70,7 +77,7 @@ class Adam(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad)
+                        weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
         super(Adam, self).__init__(params, defaults)
 
     def __setstate__(self, state):
@@ -141,5 +148,6 @@ class Adam(Optimizer):
                    beta2=beta2,
                    lr=group['lr'],
                    weight_decay=group['weight_decay'],
-                   eps=group['eps'])
+                   eps=group['eps'],
+                   maximize=group['maximize'])
         return loss