Adafactor forloop basic impl (#129905)
#109581

At this point, the vanilla implementation (the default) is good. Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor

Specifically, the implementation in this PR, which attempts to replicate the paper,

```
optim = torch.optim.Adafactor([weight])
```

is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor

```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```

which is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor

```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```

The three results, respectively, for the same randomly generated weights:

```
# ours
tensor([[ 0.3807594, -0.3912092],
        [ 0.0762539,  0.5377805],
        [ 0.2459473,  0.4662207]])

# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
        [ 0.0762507,  0.5377818],
        [ 0.2459457,  0.4662213]])

# keras
array([[ 0.38076326, -0.39121315],
       [ 0.0762547 ,  0.5377859 ],
       [ 0.24594972,  0.46622536]], dtype=float32)
```

This gives me confidence to move forward in speeding up the implementation now that a baseline has been established.

If you're curious about the differences:
* keras assigns the step size (rho_t in their code) to `min(lr, 1 / sqrt(step))`, whereas the original implementation uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01.
* We differ from the pytorch-optimizer defaults in that our default does not track momentum (thus `beta1=0`) and we do not apply parameter scaling.

Keras colab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing

My script repro:

```
import torch
from pytorch_optimizer import AdaFactor

torch.set_printoptions(precision=7)

weight = torch.tensor([[ 0.37697506, -0.39500135],
                       [ 0.07246649,  0.53399765],
                       [ 0.24216151,  0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)
weight.grad = torch.tensor([[-0.5940447, -0.7743838],
                            [-0.5940447, -0.7743838],
                            [-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974, 1.5422692], dtype=torch.float32)

weight_c = weight.clone()
weight_c.grad = weight.grad.clone()

optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)

optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
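For context, a minimal usage sketch of the new optimizer (illustrative only; the toy model and values below are made up, not taken from the PR):

```python
import torch

# A tiny model and a single optimization step with the new for-loop Adafactor.
model = torch.nn.Linear(4, 2)
optim = torch.optim.Adafactor(model.parameters())  # lr defaults to 1e-2

inp = torch.randn(8, 4)
loss = model(inp).pow(2).mean()
loss.backward()
optim.step()
optim.zero_grad()
```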
Committed by PyTorch MergeBot
Parent: e8956c9fe6
Commit: 9c4cf866c2
@@ -134,6 +134,7 @@ Algorithms
    :nosignatures:

    Adadelta
    Adafactor
    Adagrad
    Adam
    AdamW
@@ -177,6 +178,7 @@ Below is a table showing the available and default implementations of each algorithm
    :delim: ;

    :class:`Adadelta`;foreach;yes;no
    :class:`Adafactor`;for-loop;no;no
    :class:`Adagrad`;foreach;yes;yes (cpu only)
    :class:`Adam`;foreach;yes;yes
    :class:`AdamW`;foreach;yes;yes
@@ -198,6 +200,7 @@ Below table is showing the stability status for fused implementations:
    :delim: ;

    :class:`Adadelta`;unsupported;unsupported;unsupported
    :class:`Adafactor`;unsupported;unsupported;unsupported
    :class:`Adagrad`;beta;unsupported;unsupported
    :class:`Adam`;beta;stable;beta
    :class:`AdamW`;beta;stable;beta
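For example, the foreach/fused implementations listed in the tables above are selected through the optimizer constructor (a minimal sketch; note Adafactor currently ships only the for-loop implementation, so it exposes no such flags):

```python
import torch

model = torch.nn.Linear(4, 2)

# Explicitly request the foreach implementation (the default where available);
# fused=True would instead select the fused kernel on backends the table marks beta/stable.
opt_adam = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

# Adafactor only has a for-loop implementation for now, so no foreach/fused arguments.
opt_adafactor = torch.optim.Adafactor(model.parameters())
```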
@@ -1383,7 +1383,6 @@ class TestOptimRenewed(TestCase):

    @optims(optim_db, dtypes=[torch.float32])
    def test_can_load_older_state_dict(self, device, dtype, optim_info):
-       new_flags = ["maximize", "foreach", "fused", "differentiable", "capturable"]
        optim_cls = optim_info.optim_cls

        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
@@ -1417,7 +1416,7 @@ class TestOptimRenewed(TestCase):
        old_state_dict = deepcopy(optimizer.state_dict())
        old_state_dict_pg = old_state_dict["param_groups"]
        for group in old_state_dict_pg:
-           for flag in new_flags:
+           for flag in optim_info.not_og_supported_flags:
                if flag in group:
                    del group[flag]
@@ -7,6 +7,7 @@ future.
"""

from torch.optim import lr_scheduler, swa_utils
from torch.optim._adafactor import Adafactor
from torch.optim.adadelta import Adadelta
from torch.optim.adagrad import Adagrad
from torch.optim.adam import Adam
torch/optim/_adafactor.py (new file, 450 lines)
@@ -0,0 +1,450 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _disable_dynamo_if_unsupported,
    _get_scalar_dtype,
    _maximize_doc,
    Optimizer,
    ParamsT,
)

__all__ = ["Adafactor", "adafactor"]


class Adafactor(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        beta2_decay: float = -0.8,
        eps: Tuple[Optional[float], float] = (None, 1e-3),
        d: float = 1.0,
        weight_decay: float = 0.0,
        *,
        maximize: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
        if not 0.0 >= beta2_decay:
            raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
        if eps[0] is not None and not 0.0 <= eps[0]:
            raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
        if not 0.0 <= eps[1]:
            raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
        if not 1.0 <= d:
            raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
        defaults = dict(
            lr=lr,
            beta2_decay=beta2_decay,
            eps=eps,
            d=d,
            weight_decay=weight_decay,
            maximize=maximize,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            if torch.is_complex(p):
                raise RuntimeError("Adafactor does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Adafactor does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
                # This is because kernel launches are costly on CUDA and XLA.
                state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                if p.grad.dim() > 1:
                    row_shape = list(p.grad.shape)
                    row_shape[-1] = 1
                    # Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
                    state["row_var"] = p.grad.new_zeros(row_shape)

                    col_shape = list(p.grad.shape)
                    col_shape[-2] = 1
                    # Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
                    state["col_var"] = p.grad.new_zeros(col_shape)
                else:
                    state["variance"] = torch.zeros_like(
                        p.grad, memory_format=torch.preserve_format
                    )

            row_vars.append(state.get("row_var", None))
            col_vars.append(state.get("col_var", None))
            variances.append(state.get("variance", None))
            state_steps.append(state["step"])
        return False  # has_complex

    @torch.no_grad()
    def step(self, closure=None):
        r"""Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            row_vars: List[Optional[Tensor]] = []
            col_vars: List[Optional[Tensor]] = []
            variances: List[Optional[Tensor]] = []
            state_steps: List[Tensor] = []
            eps1, eps2 = group["eps"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
            )

            adafactor(
                params_with_grad,
                grads,
                row_vars,
                col_vars,
                variances,
                state_steps,
                d=group["d"],
                lr=group["lr"],
                beta2_decay=group["beta2_decay"],
                weight_decay=group["weight_decay"],
                eps1=eps1,
                eps2=eps2,
                maximize=group["maximize"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
                has_complex=has_complex,
            )

        return loss

Adafactor.__doc__ = (
    r"""Implements Adafactor algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \tau
                \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\
            &\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
            &\hspace{15mm} \: \lambda \text{(weight decay)}, \: \textit{maximize} \\
            &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
            &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
            &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
            &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\
            &\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2, \text{RMS}(\theta_{t-1}))\rho_t \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
            &\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
            &\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
            &\hspace{10mm}\widehat{V}_t \leftarrow
                \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
                (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
            &\hspace{5mm}U_t \leftarrow
                \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            &\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
            learning rate, and Shazeer, Noam, and Mitchell Stern do not use lr at all.
            Deviating from the paper, this implementation uses lr for applying weight
            decay and as the maximum value for relative step size rho_t. Note that in
            the paper, a constant of 0.01 is used as the maximum value for relative
            step size, and so we set 0.01 as the default value. (default: 1e-2)
        beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
            to the coefficient used for computing the running average of the gradient
            squared. (default: -0.8)
        eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
            of the update calculation to improve numerical stability. This use of epsilon1
            deviates from the algorithm written in the paper! See note below for more details.
            epsilon2 is the term used to avoid having too small a weight update when applying
            parameter scaling. (default: (None, 1e-3))
        d (float, optional): the clipping threshold, used to avoid larger-than-desired
            updates.
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        {_maximize_doc}"""
    + r"""
    .. Note::
        The implementation of Adafactor subtly differs from Shazeer, Noam, and Mitchell Stern
        and implementations in some other frameworks with its use of learning rate and
        :math:`\epsilon_1`.

        Regarding the learning rate hyperparameter: Shazeer, Noam, and Mitchell Stern do not
        use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
        affect the step size.

        This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
            \end{aligned}

        This differs from Shazeer, Noam, and Mitchell Stern, who use a constant of 0.01 as
        the maximum value of :math:`\rho_t`

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
            \end{aligned}

        Shazeer, Noam, and Mitchell Stern do not enforce an opinion on how weight decay should
        be computed, and so we use the learning rate as a coefficient for decoupled weight
        decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.

        Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
        presumed intention of Shazeer, Noam, and Mitchell Stern to use :math:`\epsilon_1` as
        a stabilizing term when the squared gradient becomes small.

        This stabilization can be written as

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow
                    \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
            \end{aligned}

        where the row and column factors of gradient squared :math:`R_t` and :math:`C_t`
        are left alone, and we apply :math:`\epsilon_1` at the final calculation of
        the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.

        This is in contrast to Shazeer, Noam, and Mitchell Stern and other frameworks which
        apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
        not in the calculations after:

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
                &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
            \end{aligned}


    .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
        https://arxiv.org/pdf/1804.04235
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101

    """
)

def _single_tensor_adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    # If grad is 1-dimensional (aka a vector), there is no factorization necessary
    # so row_var and col_var will be None while variance will be filled.
    # Contrarily, for a grad with multiple dimensions, we will factor along the last
    # 2 dimensions, and so row_var and col_var will be filled and variance will be None.
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    d: float,
    lr: Union[Tensor, float],
    beta2_decay: float,
    weight_decay: float,
    eps1: Optional[float],
    eps2: float,
    maximize: bool,
    has_complex: bool,
):
    assert (
        grad_scale is None and found_inf is None
    ), "Grad scaling should occur outside of optimizer.step()"

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        step_t = state_steps[i]
        row_var = row_vars[i]
        col_var = col_vars[i]
        variance = variances[i]
        if eps1 is None:
            eps1 = torch.finfo(param.dtype).eps

        # update step
        step_t += 1
        step_float = step_t.item()

        beta2_t = 1 - step_float**beta2_decay
        rho_t = min(lr, 1 / (step_float**0.5))
        alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

        # Perform stepweight decay
        if weight_decay != 0:
            param.mul_(1 - lr * weight_decay)

        if grad.dim() > 1:
            assert (
                row_var is not None and col_var is not None
            ), "row_var and col_var should be defined when grad is multidimensional"
            # same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
            row_mean = (
                torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
            )
            row_var.lerp_(row_mean, 1 - beta2_t)
            # same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
            col_mean = (
                torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
            )
            col_var.lerp_(col_mean, 1 - beta2_t)
            var_estimate = row_var @ col_var
            var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
        else:
            assert (
                variance is not None
            ), "variance should be defined when grad is a vector"
            grad_squared = grad * grad
            variance.lerp_(grad_squared, 1 - beta2_t)
            # avoid writing into variance during update
            var_estimate = variance.clone()

        # square the eps1 as we sqrt after to keep eps1's magnitude
        update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_()
        update.mul_(grad)
        denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
        param.add_(update, alpha=-alpha / denom)

@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
    params: List[Tensor],
    grads: List[Tensor],
    row_vars: List[Optional[Tensor]],
    col_vars: List[Optional[Tensor]],
    variances: List[Optional[Tensor]],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    d: float,
    lr: Union[float, Tensor],
    beta2_decay: float,
    weight_decay: float,
    eps1: float,
    eps2: float,
    maximize: bool,
):
    r"""Functional API that performs Adafactor algorithm computation.

    See :class:`~torch.optim.Adafactor` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "`state_steps` argument must contain a list of singleton tensors"
        )

    func = _single_tensor_adafactor

    func(
        params,
        grads,
        row_vars,
        col_vars,
        variances,
        state_steps,
        d=d,
        lr=lr,
        beta2_decay=beta2_decay,
        weight_decay=weight_decay,
        eps1=eps1,
        eps2=eps2,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )
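To see the factored second-moment trick used in `_single_tensor_adafactor`, here is a small standalone check (illustrative, not part of the diff): the row/column statistics are just means of the squared gradient computed without materializing a full-size `g * g` intermediate, and the variance estimate is their outer product renormalized by the mean of the row factor.

```python
import torch

g = torch.randn(5, 3)

# Same quantities as row_mean / col_mean in the kernel above.
row_mean = torch.norm(g, dim=-1, keepdim=True).square().div(g.size(-1))  # shape (5, 1)
col_mean = torch.norm(g, dim=-2, keepdim=True).square().div(g.size(-2))  # shape (1, 3)

assert torch.allclose(row_mean, (g * g).mean(dim=-1, keepdim=True))
assert torch.allclose(col_mean, (g * g).mean(dim=-2, keepdim=True))

# Factored variance estimate: outer product of the row and column factors,
# divided by the mean of the row factor (the running averages and eps1 clamp
# from the real kernel are omitted here for clarity).
var_estimate = (row_mean @ col_mean) / row_mean.mean(dim=-2, keepdim=True)
print(var_estimate.shape)  # torch.Size([5, 3])
```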
@@ -13,6 +13,7 @@ from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
    Adadelta,
    Adafactor,
    Adagrad,
    Adam,
    Adamax,
@@ -119,7 +120,15 @@ class OptimizerInfo:
        ),
        # A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
        # supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
-       supported_impls: Tuple[str] = ("foreach", "differentiable"),
+       supported_impls: Tuple[str, ...] = ("foreach", "differentiable"),
        # A subset of all flags, signifying which ones were only supported after the
        # original optimizer had already been released. aka impls where we need to check BC.
        not_og_supported_flags: Tuple[str, ...] = (
            "foreach",
            "differentiable",
            "maximize",
            "capturable",
        ),
        # the optim supports passing in sparse gradients as well as dense grads
        supports_sparse: bool = False,
        # the optim only supports one config: sparse grads w/ dense params, see SparseAdam
@@ -139,12 +148,13 @@ class OptimizerInfo:
        skips=(),  # Indicates which tests to skip
        decorators=None,  # Additional decorators to apply to generated tests
        optim_error_inputs_func=None,  # Function to generate optim inputs that error
-       supports_fused_on: Tuple[str] = (),
+       supports_fused_on: Tuple[str, ...] = (),
    ):
        self.optim_cls = optim_cls
        self.optim_inputs_func = optim_inputs_func
        self.scheduler_inputs = scheduler_inputs
        self.supported_impls = supported_impls
        self.not_og_supported_flags = not_og_supported_flags
        self.supports_sparse = supports_sparse
        self.metadata_for_sparse = metadata_for_sparse
        self.only_supports_sparse_grads = only_supports_sparse_grads
@@ -347,6 +357,79 @@ def optim_error_inputs_func_adadelta(device, dtype):
    return error_inputs


def optim_inputs_func_adafactor(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "lr": 0.01},
            desc="nonzero weight_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"weight_decay": 0.1, "maximize": True},
            desc="maximize",
        ),
        OptimizerInput(
            params=None,
            kwargs={"beta2_decay": -1.0},
            desc="non-default beta2_decay",
        ),
        OptimizerInput(
            params=None,
            kwargs={"d": 1.5},
            desc="non-default clipping threshold d",
        ),
    ]


def optim_error_inputs_func_adafactor(device, dtype):
    error_inputs = get_error_inputs_for_all_optims(device, dtype)
    if _get_device_type(device) == "cpu":
        complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
        complex_param.grad = torch.rand_like(complex_param)
        error_inputs += [
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(eps=(-1e-30, 1e-3)),
                    desc="epsilon1 should be >= 0",
                ),
                error_type=ValueError,
                error_regex="epsilon1 should be >= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(d=0.0),
                    desc="invalid d",
                ),
                error_type=ValueError,
                error_regex="Clipping threshold d should be >= 1",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=None,
                    kwargs=dict(beta2_decay=0.8),
                    desc="invalid beta2_decay",
                ),
                error_type=ValueError,
                error_regex="beta2_decay should be <= 0",
            ),
            ErrorOptimizerInput(
                OptimizerInput(
                    params=[complex_param],
                    kwargs=dict(),
                    desc="does not support complex parameters",
                ),
                error_type=RuntimeError,
                error_regex="Adafactor does not support complex parameters",
                error_on=OptimizerErrorEnum.STEP_ERROR,
            ),
        ]
    return error_inputs


def optim_inputs_func_adagrad(device, dtype=None):
    return [
        OptimizerInput(params=None, kwargs={}, desc="default"),
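As a quick illustration of the constructor validation these Adafactor error inputs exercise (a sketch, not part of the test suite):

```python
import torch

w = torch.randn(2, 3, requires_grad=True)

# Clipping threshold d must be >= 1, matching the "invalid d" error input above.
try:
    torch.optim.Adafactor([w], d=0.0)
except ValueError as e:
    print(e)  # Clipping threshold d should be >= 1 but is: 0.0

# beta2_decay must be <= 0, matching the "invalid beta2_decay" error input above.
try:
    torch.optim.Adafactor([w], beta2_decay=0.8)
except ValueError as e:
    print(e)  # beta2_decay should be <= 0 but is: 0.8
```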
@@ -1170,11 +1253,27 @@ optim_db: List[OptimizerInfo] = [
            ),
        ),
    ),
    OptimizerInfo(
        Adafactor,
        optim_inputs_func=optim_inputs_func_adafactor,
        optim_error_inputs_func=optim_error_inputs_func_adafactor,
        supported_impls=(),
        not_og_supported_flags=(),
        supports_complex=False,
        skips=(),
    ),
    OptimizerInfo(
        Adagrad,
        optim_inputs_func=optim_inputs_func_adagrad,
        optim_error_inputs_func=optim_error_inputs_func_adagrad,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_fused_on=("cpu",),
        supports_sparse=True,
        metadata_for_sparse=(
@@ -1258,6 +1357,13 @@ optim_db: List[OptimizerInfo] = [
        ),
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_fused_on=("cpu", "cuda", "mps"),
        decorators=(
            # Expected floating point error between fused and compiled forloop
@@ -1380,6 +1486,13 @@ optim_db: List[OptimizerInfo] = [
        optim_inputs_func=optim_inputs_func_adamw,
        optim_error_inputs_func=optim_error_inputs_func_adamw,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_fused_on=("cpu", "cuda", "mps"),
        decorators=(
            # Expected error between compiled forloop and fused optimizers
@@ -1779,6 +1892,13 @@ optim_db: List[OptimizerInfo] = [
        ),
        optim_error_inputs_func=optim_error_inputs_func_sgd,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_sparse=True,
        metadata_for_sparse=(
            {