Adafactor forloop basic impl (#129905)

#109581

At this point, the vanilla implementation (the default) is good.
Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor

Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```

The three results respectively for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
        [ 0.0762539,  0.5377805],
        [ 0.2459473,  0.4662207]])

# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
        [ 0.0762507,  0.5377818],
        [ 0.2459457,  0.4662213]])

# keras
array([[ 0.38076326, -0.39121315],
        [ 0.0762547 ,  0.5377859 ],
        [ 0.24594972,  0.46622536]], dtype=float32)
```

This gives me confidence to move forward in speeding up the implementation now that a baseline has been established. If you're curious about differences:
* keras assigns step_size (rho_t in their code) to `min(lr, 1 / sqrt(step)` whereas the OG impl uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01.
* We differ from the pytorch-optimizers default in that our default will not track momentum (thus `beta1=0`) and we do not apply parameter scaling.

<details>

Keras collab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing

My script repro:

```
import torch
from pytorch_optimizer import AdaFactor
torch.set_printoptions(precision=7)

weight = torch.tensor([[ 0.37697506, -0.39500135],
        [ 0.07246649,  0.53399765],
        [ 0.24216151,  0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)

weight.grad = torch.tensor([[-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974,  1.5422692], dtype=torch.float32)

weight_c = weight.clone()
weight_c.grad = weight.grad.clone()

optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)

optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```

<details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu
2024-07-24 14:22:47 -07:00
committed by PyTorch MergeBot
parent e8956c9fe6
commit 9c4cf866c2
5 changed files with 577 additions and 4 deletions

View File

@ -134,6 +134,7 @@ Algorithms
:nosignatures:
Adadelta
Adafactor
Adagrad
Adam
AdamW
@ -177,6 +178,7 @@ Below is a table showing the available and default implementations of each algor
:delim: ;
:class:`Adadelta`;foreach;yes;no
:class:`Adafactor`;for-loop;no;no
:class:`Adagrad`;foreach;yes;yes (cpu only)
:class:`Adam`;foreach;yes;yes
:class:`AdamW`;foreach;yes;yes
@ -198,6 +200,7 @@ Below table is showing the stability status for fused implementations:
:delim: ;
:class:`Adadelta`;unsupported;unsupported;unsupported
:class:`Adafactor`;unsupported;unsupported;unsupported
:class:`Adagrad`;beta;unsupported;unsupported
:class:`Adam`;beta;stable;beta
:class:`AdamW`;beta;stable;beta

View File

@ -1383,7 +1383,6 @@ class TestOptimRenewed(TestCase):
@optims(optim_db, dtypes=[torch.float32])
def test_can_load_older_state_dict(self, device, dtype, optim_info):
new_flags = ["maximize", "foreach", "fused", "differentiable", "capturable"]
optim_cls = optim_info.optim_cls
# Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
@ -1417,7 +1416,7 @@ class TestOptimRenewed(TestCase):
old_state_dict = deepcopy(optimizer.state_dict())
old_state_dict_pg = old_state_dict["param_groups"]
for group in old_state_dict_pg:
for flag in new_flags:
for flag in optim_info.not_og_supported_flags:
if flag in group:
del group[flag]

View File

@ -7,6 +7,7 @@ future.
"""
from torch.optim import lr_scheduler, swa_utils
from torch.optim._adafactor import Adafactor
from torch.optim.adadelta import Adadelta
from torch.optim.adagrad import Adagrad
from torch.optim.adam import Adam

450
torch/optim/_adafactor.py Normal file
View File

@ -0,0 +1,450 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from .optimizer import (
_disable_dynamo_if_unsupported,
_get_scalar_dtype,
_maximize_doc,
Optimizer,
ParamsT,
)
__all__ = ["Adafactor", "adafactor"]
class Adafactor(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-2,
beta2_decay: float = -0.8,
eps: Tuple[Optional[float], float] = (None, 1e-3),
d: float = 1.0,
weight_decay: float = 0.0,
*,
maximize: bool = False,
):
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
if not 0.0 >= beta2_decay:
raise ValueError(f"beta2_decay should be <= 0 but is: {beta2_decay}")
if eps[0] is not None and not 0.0 <= eps[0]:
raise ValueError(f"epsilon1 should be >= 0 but is: {eps[0]}")
if not 0.0 <= eps[1]:
raise ValueError(f"epsilon2 should be >= 0 but is: {eps[1]}")
if not 1.0 <= d:
raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
if not 0.0 <= weight_decay:
raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
defaults = dict(
lr=lr,
beta2_decay=beta2_decay,
eps=eps,
d=d,
weight_decay=weight_decay,
maximize=maximize,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
step_val = float(p_state["step"])
p_state["step"] = torch.tensor(step_val, dtype=_get_scalar_dtype())
def _init_group(
self,
group,
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
):
for p in group["params"]:
if p.grad is None:
continue
if torch.is_complex(p):
raise RuntimeError("Adafactor does not support complex parameters")
if p.grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients")
params_with_grad.append(p)
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())
if p.grad.dim() > 1:
row_shape = list(p.grad.shape)
row_shape[-1] = 1
# Row factor of variance, NOT the same shape as grads (will be reduced along last dim)
state["row_var"] = p.grad.new_zeros(row_shape)
col_shape = list(p.grad.shape)
col_shape[-2] = 1
# Col factor of variance, NOT the same shape as grads (will be reduced along penultimate dim)
state["col_var"] = p.grad.new_zeros(col_shape)
else:
state["variance"] = torch.zeros_like(
p.grad, memory_format=torch.preserve_format
)
row_vars.append(state.get("row_var", None))
col_vars.append(state.get("col_var", None))
variances.append(state.get("variance", None))
state_steps.append(state["step"])
return False # has_complex
@torch.no_grad()
def step(self, closure=None):
r"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad: List[Tensor] = []
grads: List[Tensor] = []
row_vars: List[Optional[Tensor]] = []
col_vars: List[Optional[Tensor]] = []
variances: List[Optional[Tensor]] = []
state_steps: List[Tensor] = []
eps1, eps2 = group["eps"]
has_complex = self._init_group(
group,
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
)
adafactor(
params_with_grad,
grads,
row_vars,
col_vars,
variances,
state_steps,
d=group["d"],
lr=group["lr"],
beta2_decay=group["beta2_decay"],
weight_decay=group["weight_decay"],
eps1=eps1,
eps2=eps2,
maximize=group["maximize"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
has_complex=has_complex,
)
return loss
Adafactor.__doc__ = (
r"""Implements Adafactor algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{(lr)}, \: \tau
\text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, \\
&\hspace{15mm} \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
&\hspace{15mm} \: \lambda \text{(weight decay)},
\: \textit{maximize} \\
&\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)}, \\
&\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)}, \\
&\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}G_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}G_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau} \\
&\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}}) \\
&\hspace{5mm}\alpha_t \leftarrow max(\epsilon_2,
\text{RMS}(\theta_{t-1}))\rho_t \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1: \\
&\hspace{10mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
(1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m \\
&\hspace{10mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
(1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t) \\
&\hspace{10mm}\widehat{V}_t \leftarrow
\frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
(1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t) \\
&\hspace{5mm}U_t \leftarrow
\frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
&\hspace{5mm}\widehat{U}_t \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
&\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
"""
+ rf"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
learning rate, and Shazeer, Noam, and Mitchell Stern do not use lr at all.
Deviating from the paper, this implementation uses lr for applying weight
decay and as the maximum value for relative step size rho_t. Note that in
the paper, a constant of 0.01 is used as the maximum value for relative
step size, and so we set 0.01 as the default value. (default: 1e-2)
beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
to the coefficient used for computing the running average of the gradient
squared. (default: -0.8)
eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
of the update calculation to improve numerical stability. This use of epsilon1
deviates from the algorithm written in the paper! See note below for more details.
epsilon2 is the term used to avoid having too small a weight update when applying
parameter scaling. (default: (None, 1e-3))
d (float, optional): the clipping threshold, used to avoid larger-than-desired
updates.
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
{_maximize_doc}"""
+ r"""
.. Note::
The implementation of Adafactor subtly differs from Shazeer, Noam, and Mitchell Stern
and implementations in some other frameworks with its use of learning rate and
:math:`\epsilon_1`.
Regarding the learning rate hyperparameter: Shazeer, Noam, and Mitchell Stern do not
use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
affect the step size.
This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:
.. math::
\begin{aligned}
&\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
\end{aligned}
This differs from Shazeer, Noam, and Mitchell Stern, who use a constant of 0.01 as
the maximum value of :math:`\rho_t`
.. math::
\begin{aligned}
&\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
\end{aligned}
Shazeer, Noam, and Mitchell Stern do not enforce an opinion on how weight decay should
be computed, and so we use the learning rate as a coefficient for decoupled weight
decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.
Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
presumed intention of Shazeer, Noam, and Mitchell Stern to use :math:`\epsilon_1` as
a stabilizing term when the squared gradient becomes small.
This stabilization can be written as
.. math::
\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
(1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
(1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m) \\
&\hspace{5mm}\widehat{V}_t \leftarrow
\frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)} \\
\end{aligned}
where the row and column factors of gradient squared :math:`R_t` and :math:`C_t`
are left alone, and we apply :math:`\epsilon_1` at the final calculation of
the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.
This is in contrast to Shazeer, Noam, and Mitchell Stern and other frameworks which
apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
not in the calculations after:
.. math::
\begin{aligned}
&\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
(1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m \\
&\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
(1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \\
&\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t} \\
&\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}} \\
\end{aligned}
.. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
https://arxiv.org/pdf/1804.04235
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
"""
)
def _single_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
# If grad is 1-dimensional (aka a vector), there is no factorization necessary
# so row_var and col_var will be None while variance will be filled.
# Contrarily, for a grad with multiple dimensions, we will factor along the last
# 2 dimensions, and so row_var and col_var will be filled and variance will be None.
row_vars: List[Optional[Tensor]],
col_vars: List[Optional[Tensor]],
variances: List[Optional[Tensor]],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
d: float,
lr: Union[Tensor, float],
beta2_decay: float,
weight_decay: float,
eps1: Optional[float],
eps2: float,
maximize: bool,
has_complex: bool,
):
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
step_t = state_steps[i]
row_var = row_vars[i]
col_var = col_vars[i]
variance = variances[i]
if eps1 is None:
eps1 = torch.finfo(param.dtype).eps
# update step
step_t += 1
step_float = step_t.item()
beta2_t = 1 - step_float**beta2_decay
rho_t = min(lr, 1 / (step_float**0.5))
alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t
# Perform stepweight decay
if weight_decay != 0:
param.mul_(1 - lr * weight_decay)
if grad.dim() > 1:
assert (
row_var is not None and col_var is not None
), "row_var and col_var should be defined when grad is multidimensional"
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_mean = (
torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
)
row_var.lerp_(row_mean, 1 - beta2_t)
# same as (g * g).mean(dim=-2) w/o materializing an intermediate size g
col_mean = (
torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2))
)
col_var.lerp_(col_mean, 1 - beta2_t)
var_estimate = row_var @ col_var
var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
else:
assert (
variance is not None
), "variance should be defined when grad is a vector"
grad_squared = grad * grad
variance.lerp_(grad_squared, 1 - beta2_t)
# avoid writing into variance during update
var_estimate = variance.clone()
# square the eps1 as we sqrt after to keep eps1's magnitude
update = var_estimate.clamp_(min=eps1 * eps1).rsqrt_()
update.mul_(grad)
denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))
param.add_(update, alpha=-alpha / denom)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adafactor)
def adafactor(
params: List[Tensor],
grads: List[Tensor],
row_vars: List[Optional[Tensor]],
col_vars: List[Optional[Tensor]],
variances: List[Optional[Tensor]],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
d: float,
lr: Union[float, Tensor],
beta2_decay: float,
weight_decay: float,
eps1: float,
eps2: float,
maximize: bool,
):
r"""Functional API that performs Adafactor algorithm computation.
See :class:`~torch.optim.Adafactor` for details.
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"`state_steps` argument must contain a list of singleton tensors"
)
func = _single_tensor_adafactor
func(
params,
grads,
row_vars,
col_vars,
variances,
state_steps,
d=d,
lr=lr,
beta2_decay=beta2_decay,
weight_decay=weight_decay,
eps1=eps1,
eps2=eps2,
maximize=maximize,
grad_scale=grad_scale,
found_inf=found_inf,
has_complex=has_complex,
)

View File

@ -13,6 +13,7 @@ from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
Adadelta,
Adafactor,
Adagrad,
Adam,
Adamax,
@ -119,7 +120,15 @@ class OptimizerInfo:
),
# A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
# supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
supported_impls: Tuple[str] = ("foreach", "differentiable"),
supported_impls: Tuple[str, ...] = ("foreach", "differentiable"),
# A subset of all flags, signifying which ones were only supported after the
# original optimizer had already been released. aka impls where we need to check BC.
not_og_supported_flags: Tuple[str, ...] = (
"foreach",
"differentiable",
"maximize",
"capturable",
),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
@ -139,12 +148,13 @@ class OptimizerInfo:
skips=(), # Indicates which tests to skip
decorators=None, # Additional decorators to apply to generated tests
optim_error_inputs_func=None, # Function to generate optim inputs that error
supports_fused_on: Tuple[str] = (),
supports_fused_on: Tuple[str, ...] = (),
):
self.optim_cls = optim_cls
self.optim_inputs_func = optim_inputs_func
self.scheduler_inputs = scheduler_inputs
self.supported_impls = supported_impls
self.not_og_supported_flags = not_og_supported_flags
self.supports_sparse = supports_sparse
self.metadata_for_sparse = metadata_for_sparse
self.only_supports_sparse_grads = only_supports_sparse_grads
@ -347,6 +357,79 @@ def optim_error_inputs_func_adadelta(device, dtype):
return error_inputs
def optim_inputs_func_adafactor(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "lr": 0.01},
desc="nonzero weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"beta2_decay": -1.0},
desc="non-default beta2_decay",
),
OptimizerInput(
params=None,
kwargs={"d": 1.5},
desc="non-default clipping threshold d",
),
]
def optim_error_inputs_func_adafactor(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
complex_param.grad = torch.rand_like(complex_param)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(eps=(-1e-30, 1e-3)),
desc="epsilon1 should be >= 0",
),
error_type=ValueError,
error_regex="epsilon1 should be >= 0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(d=0.0),
desc="invalid d",
),
error_type=ValueError,
error_regex="Clipping threshold d should be >= 1",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(beta2_decay=0.8),
desc="invalid beta2_decay",
),
error_type=ValueError,
error_regex="beta2_decay should be <= 0",
),
ErrorOptimizerInput(
OptimizerInput(
params=[complex_param],
kwargs=dict(),
desc="does not support complex parameters",
),
error_type=RuntimeError,
error_regex="Adafactor does not support complex parameters",
error_on=OptimizerErrorEnum.STEP_ERROR,
),
]
return error_inputs
def optim_inputs_func_adagrad(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
@ -1170,11 +1253,27 @@ optim_db: List[OptimizerInfo] = [
),
),
),
OptimizerInfo(
Adafactor,
optim_inputs_func=optim_inputs_func_adafactor,
optim_error_inputs_func=optim_error_inputs_func_adafactor,
supported_impls=(),
not_og_supported_flags=(),
supports_complex=False,
skips=(),
),
OptimizerInfo(
Adagrad,
optim_inputs_func=optim_inputs_func_adagrad,
optim_error_inputs_func=optim_error_inputs_func_adagrad,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu",),
supports_sparse=True,
metadata_for_sparse=(
@ -1258,6 +1357,13 @@ optim_db: List[OptimizerInfo] = [
),
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu", "cuda", "mps"),
decorators=(
# Expected floating point error between fused and compiled forloop
@ -1380,6 +1486,13 @@ optim_db: List[OptimizerInfo] = [
optim_inputs_func=optim_inputs_func_adamw,
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu", "cuda", "mps"),
decorators=(
# Expected error between compiled forloop and fused optimizers
@ -1779,6 +1892,13 @@ optim_db: List[OptimizerInfo] = [
),
optim_error_inputs_func=optim_error_inputs_func_sgd,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_sparse=True,
metadata_for_sparse=(
{