mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237 Approved by: https://github.com/albanD
288 lines
12 KiB
Python
288 lines
12 KiB
Python
import math
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from .optimizer import Optimizer
|
|
from typing import List, Optional
|
|
|
|
# Public names of this module: the RAdam optimizer class and its functional form.
__all__ = ['RAdam', 'radam']
|
|
|
|
class RAdam(Optimizer):
    r"""Implements RAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                \text{ (betas)}, \: \theta_0 \text{ (params)}, \:f(\theta) \text{ (objective)}, \:
                \lambda \text{ (weight decay)},                                                  \\
            &\hspace{13mm} \epsilon \text{ (epsilon)}                                            \\
            &\textbf{initialize} :  m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)},                                       \\
            &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) -1                      \\[-1.ex]
            &\rule{110mm}{0.4pt}  \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{6mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{6mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{6mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{6mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
                2 t \beta^t_2 /\big(1-\beta_2^t \big)                                     \\[0.1ex]
            &\hspace{6mm}\textbf{if} \: \rho_t > 5                                               \\
            &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} +\epsilon }  \\
            &\hspace{12mm} r_t \leftarrow
      \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t        \\
            &\hspace{6mm}\textbf{else}                                                           \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}                \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `On the variance of the adaptive learning rate and beyond`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)

    Per-parameter state (created lazily on the first step that sees a gradient):
        ``step`` (0-dim float tensor), ``exp_avg`` and ``exp_avg_sq``
        (tensors shaped like the parameter).

    .. _On the variance of the adaptive learning rate and beyond:
        https://arxiv.org/abs/1908.03265
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, foreach: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                        foreach=foreach)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Groups restored from state that lacks a 'foreach' entry get the default.
        for group in self.param_groups:
            group.setdefault('foreach', None)
        # The functional API requires every 'step' to be a tensor. Only the first
        # state entry is inspected; if its 'step' is not a tensor, every entry's
        # 'step' is converted to a 0-dim float tensor.
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; re-enable autograd so the closure can
            # recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                # Parameters without a gradient are skipped entirely.
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = torch.tensor(0.)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                state_steps.append(state['step'])

            radam(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  state_steps,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  foreach=group['foreach'])

        return loss


def radam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          state_steps: List[Tensor],
          # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
          # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
          foreach: Optional[bool] = None,
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          eps: float):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.

    Args:
        params: parameters to update in place.
        grads: gradients, one per entry of ``params``.
        exp_avgs: first-moment running averages, updated in place.
        exp_avg_sqs: second-moment running averages, updated in place.
        state_steps: per-parameter step counters as singleton tensors; each is
            incremented in place.
        foreach: use the multi-tensor implementation when True; ``None`` is
            currently treated as False.
        beta1, beta2, lr, weight_decay, eps: hyper-parameters as described in
            :class:`~torch.optim.RAdam`.

    Raises:
        RuntimeError: if any entry of ``state_steps`` is not a tensor, or if
            ``foreach`` is requested while running under TorchScript.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    # At runtime the is_scripting() check below is redundant with the raise above;
    # presumably it is kept so TorchScript never compiles the multi-tensor path.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(params,
         grads,
         exp_avgs,
         exp_avg_sqs,
         state_steps,
         beta1=beta1,
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
         eps=eps)


def _single_tensor_radam(params: List[Tensor],
                         grads: List[Tensor],
                         exp_avgs: List[Tensor],
                         exp_avg_sqs: List[Tensor],
                         state_steps: List[Tensor],
                         *,
                         beta1: float,
                         beta2: float,
                         lr: float,
                         weight_decay: float,
                         eps: float):
    r"""Apply one RAdam step to each parameter, one tensor at a time.

    For every index ``i`` this increments ``state_steps[i]``, updates the moment
    estimates ``exp_avgs[i]`` / ``exp_avg_sqs[i]`` in place, and then updates
    ``params[i]`` in place. ``grads[i]`` itself is never modified: weight decay
    is folded into a new tensor.
    """
    for idx, param in enumerate(params):
        grad = grads[idx]
        m = exp_avgs[idx]
        v = exp_avg_sqs[idx]
        step_tensor = state_steps[idx]

        # Advance the step counter in place, then read it back as a Python number.
        step_tensor += 1
        t = step_tensor.item()

        bc1 = 1 - beta1 ** t
        bc2 = 1 - beta2 ** t

        if weight_decay != 0:
            # Out-of-place so the caller's gradient tensor is left untouched.
            grad = grad.add(param, alpha=weight_decay)

        # Exponential moving averages of the gradient and its square.
        m.mul_(beta1).add_(grad, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        m_hat = m / bc1

        # Maximum and current length of the approximated SMA.
        rho_inf = 2 / (1 - beta2) - 1
        rho_t = rho_inf - 2 * t * (beta2 ** t) / bc2

        if rho_t > 5.:
            # Variance is tractable: scale by the rectification term and the
            # adaptive learning rate.
            rect = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t))
            adaptive_lr = math.sqrt(bc2) / v.sqrt().add_(eps)
            update = m_hat * lr * adaptive_lr * rect
        else:
            # Early steps: plain bias-corrected momentum update.
            update = m_hat * lr

        param.add_(update, alpha=-1.0)


def _multi_tensor_radam(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[Tensor],
                        exp_avg_sqs: List[Tensor],
                        state_steps: List[Tensor],
                        *,
                        beta1: float,
                        beta2: float,
                        lr: float,
                        weight_decay: float,
                        eps: float):
    r"""Apply one RAdam step to a list of parameters using ``torch._foreach_*`` ops.

    Updates ``params``, ``exp_avgs``, ``exp_avg_sqs`` and ``state_steps`` in
    place. ``grads`` is left untouched, and the update computed for each
    parameter is the same as in the per-tensor implementation, up to
    floating-point rounding.
    """
    if not params:
        return

    # Update steps
    torch._foreach_add_(state_steps, 1)
    # Read each step counter once; all per-parameter scalars below are Python floats.
    steps = [step_t.item() for step_t in state_steps]

    bias_correction1 = [1 - beta1 ** step for step in steps]
    bias_correction2 = [1 - beta2 ** step for step in steps]

    # maximum length of the approximated SMA
    rho_inf = 2 / (1 - beta2) - 1
    # length of the approximated SMA at each parameter's current step
    rho_t_list = [rho_inf - 2 * step * (beta2 ** step) / bc2
                  for step, bc2 in zip(steps, bias_correction2)]

    if weight_decay != 0:
        # Out-of-place: an in-place add would overwrite the caller's p.grad
        # tensors (the per-tensor path also uses an out-of-place add).
        grads = torch._foreach_add(grads, params, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)

    # Partition parameters by whether the variance is tractable (rho_t > 5).
    # Each group runs only its own update, so the inactive branch can never
    # inject NaN into the parameters (e.g. 0 * (0 / 0) where exp_avg_sq == 0).
    rectified_idx: List[int] = []
    unrectified_idx: List[int] = []
    for i, rho_t in enumerate(rho_t_list):
        if rho_t > 5:
            rectified_idx.append(i)
        else:
            unrectified_idx.append(i)

    if rectified_idx:
        rect = [math.sqrt((rho_t_list[i] - 4) * (rho_t_list[i] - 2) * rho_inf
                          / ((rho_inf - 4) * (rho_inf - 2) * rho_t_list[i]))
                for i in rectified_idx]
        # denom = (sqrt(v_t) + eps) / sqrt(1 - beta2^t). Dividing by it equals
        # multiplying by sqrt(1 - beta2^t) / (sqrt(v_t) + eps), the adaptive lr.
        denom = torch._foreach_sqrt([exp_avg_sqs[i] for i in rectified_idx])
        torch._foreach_add_(denom, eps)
        denom = torch._foreach_div(denom, [math.sqrt(bias_correction2[i]) for i in rectified_idx])
        step_size = [-(lr * r / bias_correction1[i]) for r, i in zip(rect, rectified_idx)]
        torch._foreach_addcdiv_([params[i] for i in rectified_idx],
                                [exp_avgs[i] for i in rectified_idx],
                                denom,
                                step_size)

    if unrectified_idx:
        # Plain bias-corrected momentum step: theta -= lr * m_t / (1 - beta1^t)
        step_size = [-lr / bias_correction1[i] for i in unrectified_idx]
        updates = torch._foreach_mul([exp_avgs[i] for i in unrectified_idx], step_size)
        torch._foreach_add_([params[i] for i in unrectified_idx], updates)