pytorch/torch/testing/_internal/common_optimizers.py
Chuanhao Zhuge 74280d0913 [muon] Introduce Muon optimizer to PyTorch (#160213)
A single-device version of Muon. The algorithm follows Keller Jordan's [Muon blog post](https://kellerjordan.github.io/posts/muon/) and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning-rate adjustment strategy.

This implementation maintains a minimalist API consistent with other optimizer conventions. The PyTorch team prefers to handle parameter filtering at a higher level: the Muon optimizer performs only the msign computation for orthogonalization on all parameters it receives, and users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below; a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.

**Usage**

```python
    model = MyModelForCausalLM(...)
    # filter out your params manually
    muon_params = [...]
    adamw_params = [...]
    muon = Muon(
        params=muon_params,
        lr=lr,
        weight_decay=wd,
    )
    adamw = AdamW(
        params=adamw_params,
        lr=lr,
        weight_decay=wd,
    )

    # in training loop
    loss = model(input)
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
```
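
For the manual filtering step above, one possible heuristic (a sketch, not a prescribed API) is to route 2D weight matrices to Muon and everything else to AdamW, since this Muon implementation only accepts 2D parameters (see the Note below):

```python
# Sketch of one possible split; whether embeddings / output heads should also
# be routed to AdamW is a modeling choice left to the user.
muon_params = [p for p in model.parameters() if p.requires_grad and p.ndim == 2]
adamw_params = [p for p in model.parameters() if p.requires_grad and p.ndim != 2]
```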

~~**Additional usage**~~
~~Users can also pass in a self-defined `msign` function for orthogonalization and a learning-rate adjustment function, with the interface defined below:~~

```python
AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]
MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]
```

As discussed with the team and in the comments, we prefer a simpler, cleaner interface, so we removed the callback interface and canonicalized the original Newton-Schulz (NS) algorithm for Muon. The only configs available to users are `ns_steps`, `ns_coefficients`, and `eps`, configurable through kwargs.

By default, we use 5-step Newton-Schulz with the coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use the LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts the learning rate from AdamW.
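
For readers unfamiliar with the orthogonalization step, below is a minimal sketch of the 5-step Newton-Schulz iteration as described in Keller's post (the coefficients and the bfloat16 cast follow the blog post's reference code; the function and variable names are illustrative, not PyTorch's internal API):

```python
import torch
from torch import Tensor


def newton_schulz_msign(grad: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
    """Approximate msign(G) = U V^T for G = U S V^T via Newton-Schulz iteration.

    Minimal sketch following Keller Jordan's Muon blog post; not PyTorch's
    internal implementation.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)  # quintic iteration coefficients from the post
    x = grad.bfloat16()
    x = x / (x.norm() + eps)  # normalize so the top singular value is <= 1
    transposed = x.size(0) > x.size(1)
    if transposed:  # iterate on the "wide" orientation for cheaper matmuls
        x = x.mT
    for _ in range(steps):
        gram = x @ x.mT
        x = a * x + (b * gram + c * gram @ gram) @ x
    if transposed:
        x = x.mT
    return x.to(grad.dtype)
```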

**Testing**

~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~

1. Unit tests: as discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch. We also updated the interface so we no longer need to test FQN-based filtering. Muon is covered by the existing `test/test_optim.py` unit tests; a minimal smoke-test sketch is shown after this list.

2. End-to-end test: we added a training script that pre-trains a Qwen-like model on the `openwebtext-100k` dataset. We trained for one epoch, and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />

**Numerics**
We evaluate our implementation against the existing implementation to confirm numerical consistency.

As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning-rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyperparameters tuned for AdamW, making Muon a drop-in replacement.
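
To make the grafting concrete, here is a sketch of the Moonlight-style adjustment (the idea behind the `match_rms_adamw` option): the update for a `(d_out, d_in)` matrix is rescaled so its RMS roughly matches a typical AdamW update, which is what lets AdamW-tuned learning rates carry over. The `0.2 * sqrt(max(...))` form follows the Moonlight report; the function name below is illustrative, not PyTorch's internal API:

```python
import math

import torch


def adjust_lr_match_rms_adamw(lr: float, param_shape: torch.Size) -> float:
    # Moonlight-style grafting (sketch): scale Muon's update so its RMS is
    # comparable to a typical AdamW update, using the larger matrix dimension.
    d_out, d_in = param_shape
    return lr * 0.2 * math.sqrt(max(d_out, d_in))
```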

As expected, the numerics difference comes mainly from `adjust_lr`, with a maximum relative difference of ~5% in the example unit-test setup below.

```python
    import copy

    import torch
    from torch.nn import Linear, MSELoss
    from torch.optim import Muon

    # KellySingleDeviceMuon: the reference single-device Muon implementation
    # used for comparison (not part of torch.optim).

    # dummy model and data
    model0 = Linear(10, 10, bias=False)
    model1 = copy.deepcopy(model0)
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 10)
    loss = MSELoss()

    lr = 1e-3
    wd = 0.1
    momentum = 0.95

    opt_ref_muon = KellySingleDeviceMuon(
        params=model0.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    opt_exp_muon = Muon(
        params=model1.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    out_ref = model0(inputs)
    loss_ref = loss(out_ref, targets)
    opt_ref_muon.zero_grad()
    loss_ref.backward()
    opt_ref_muon.step()

    out_exp = model1(inputs)
    loss_exp = loss(out_exp, targets)
    opt_exp_muon.zero_grad()
    loss_exp.backward()
    opt_exp_muon.step()

    for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
        torch.testing.assert_close(p_ref, p_exp)
```

As explained above, including this `adjust_lr` is preferable. This is validated by an end-to-end training run on a Qwen-2-like 0.5B model, where the curves show that training with `adjust_lr` converges more effectively than without.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />

**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20 s slower than AdamW; assuming no other changes, Muon is about *2.5%* slower than AdamW.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />

**Note**
We restrict the implementation to accept only 2D parameters.

An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We chose not to take this approach because it can be error-prone: for example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.

Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
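
This restriction is enforced at construction time; the error inputs later in this file expect a `ValueError` whose message matches "Muon only supports 2D parameters". A minimal illustration (the parameter shapes here are arbitrary):

```python
import torch
from torch.optim import Muon

ok_param = torch.nn.Parameter(torch.randn(4, 8))      # 2D: accepted
bad_param = torch.nn.Parameter(torch.randn(2, 3, 4))  # 3D: rejected

Muon([ok_param], lr=1e-3)  # constructs fine

try:
    Muon([bad_param], lr=1e-3)
except ValueError as e:
    # Expected: message mentions "Muon only supports 2D parameters",
    # matching the error_regex used in optim_error_inputs_func_muon below.
    print(e)
```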

**Next Steps**

1. Add `MuP` support.
2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x-1.48x speedup on small to large matrices (n = 256 to 16384).
3. Open-source unsharded Muon co-designed with FSDP2.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
2025-08-24 08:03:04 +00:00


# mypy: ignore-errors
import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
from typing import Any, Union
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
Adadelta,
Adafactor,
Adagrad,
Adam,
Adamax,
AdamW,
ASGD,
LBFGS,
Muon,
NAdam,
Optimizer,
RAdam,
RMSprop,
Rprop,
SGD,
SparseAdam,
)
from torch.optim.lr_scheduler import (
ConstantLR,
ExponentialLR,
LinearLR,
PolynomialLR,
ReduceLROnPlateau,
StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
_TestParametrizer,
skipIfMPS,
skipIfTorchDynamo,
skipIfXpu,
TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
class OptimizerInput:
"""Contains args / kwargs to be passed to an optimizer constructor."""
__slots__ = ["params", "kwargs", "desc"]
def __init__(
self,
params: Union[
list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
],
kwargs: dict[str, Any],
desc: str = "",
):
# params can be a list of Tensors OR param_groups OR None
self.params = params
self.kwargs = kwargs
self.desc = desc
def __repr__(self):
return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"
class OptimizerErrorEnum(Enum):
"""Enumerates when an error is raised when testing optimizers."""
CONSTRUCTION_ERROR = 0
STEP_ERROR = 1
class ErrorOptimizerInput:
"""
An OptimizerInput that will cause the optimizer to throw an error when constructed.
Includes the type and string of the resulting error.
"""
__slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]
def __init__(
self,
optimizer_error_input,
*,
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
error_type=RuntimeError,
error_regex="",
):
self.optimizer_error_input = optimizer_error_input
self.error_on = error_on
self.error_type = error_type
self.error_regex = error_regex
class OptimizerInfo:
"""Optimizer information to be used in testing."""
def __init__(
self,
optim_cls: Optimizer, # Class object for the Optimizer under test
*,
# Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
# to the test using the OptimizerInfo. OptimizerInput.params is likely None.
# Can optionally take in device to filter out certain unsupported configs
optim_inputs_func,
# Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
# LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
# We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
# LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
# A few optimizers like SGD and Adam will test more LRSchedulers.
scheduler_inputs=(
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
# A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
# supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
supported_impls: tuple[str, ...] = ("foreach", "differentiable"),
# A subset of all flags, signifying which ones were only supported after the
# original optimizer had already been released. aka impls where we need to check BC.
not_og_supported_flags: tuple[str, ...] = (
"foreach",
"differentiable",
"maximize",
"capturable",
),
# the optim supports passing in sparse gradients as well as dense grads
supports_sparse: bool = False,
# the optimizer constructor supports passing in capturable as a kwarg
has_capturable_arg: bool = False,
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
only_supports_sparse_grads: bool = False,
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
# with especially tuned hyperparameters. These only apply if the optimizer supports
# sparse parameters or grads.
metadata_for_sparse=({}, []),
# the optim supports complex parameters
supports_complex: bool = True,
# whether the optimizer.step() function requires a closure to be passed
step_requires_closure: bool = False,
# whether the optimizer supports per-param options with parameter groups
supports_param_groups: bool = True,
# whether the optimizer supports parameters on multiple devices
supports_multiple_devices: bool = True,
skips=(), # Indicates which tests to skip
decorators=None, # Additional decorators to apply to generated tests
optim_error_inputs_func=None, # Function to generate optim inputs that error
supports_fused_on: tuple[str, ...] = (),
):
self.optim_cls = optim_cls
self.optim_inputs_func = optim_inputs_func
self.scheduler_inputs = scheduler_inputs
self.supported_impls = supported_impls
self.not_og_supported_flags = not_og_supported_flags
self.supports_sparse = supports_sparse
self.has_capturable_arg = has_capturable_arg
self.metadata_for_sparse = metadata_for_sparse
self.only_supports_sparse_grads = only_supports_sparse_grads
self.supports_complex = supports_complex
self.step_requires_closure = step_requires_closure
self.supports_param_groups = supports_param_groups
self.supports_multiple_devices = supports_multiple_devices
self.decorators = (
*(decorators if decorators else []),
*(skips if skips else []),
)
self.optim_error_inputs_func = optim_error_inputs_func
self.supports_fused_on = supports_fused_on
def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
result = []
for decorator in self.decorators:
if isinstance(decorator, DecorateInfo):
if decorator.is_active(
test_class, test_name, device, dtype, param_kwargs
):
result.extend(decorator.decorators)
else:
result.append(decorator)
return result
@property
def name(self):
return self.optim_cls.__name__
class optims(_TestParametrizer):
"""Decorator for specifying a list of optimizers over which to run a test."""
def __init__(self, optim_info_iterable, dtypes=None):
self.optim_info_list = list(optim_info_iterable)
# optimizers aren't limited to be one dtype as parameters can have different dtypes
# We default to torch.float32, but dtypes should be specified through passed in
# parameters.
self.dtypes = dtypes if dtypes is not None else [torch.float32]
def _parametrize_test(self, test, generic_cls, device_cls):
if device_cls is None:
raise RuntimeError(
"The @optims decorator is only intended to be used in a device-specific "
"context; use it with instantiate_device_type_tests() instead of "
"instantiate_parametrized_tests()"
)
for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
# Construct the test name; device / dtype parts are handled outside.
# See [Note: device and dtype suffix placement]
test_name = optim_info.name
# Construct parameter kwargs to pass to the test.
param_kwargs = {"optim_info": optim_info, "dtype": dtype}
try:
@functools.wraps(test)
def test_wrapper(*args, **kwargs):
return test(*args, **kwargs)
decorator_fn = functools.partial(
optim_info.get_decorators,
generic_cls.__name__,
test.__name__,
device_cls.device_type,
dtype,
)
yield (test_wrapper, test_name, param_kwargs, decorator_fn)
except Exception as ex:
# Provides an error message for debugging before rethrowing the exception
print(
f"Failed to instantiate {test_name} for module {optim_info.name}!"
)
raise ex
# Helper function for generating error inputs for all optimizers, used below.
def get_error_inputs_for_all_optims(device, dtype):
if _get_device_type(device) == "cpu":
# Creating 2D parameters for compatibility with Muon.
sample_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
sample_param2 = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
return [
ErrorOptimizerInput(
OptimizerInput(
params=sample_param,
kwargs={},
desc="invalid param type",
),
error_type=TypeError,
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
),
ErrorOptimizerInput(
OptimizerInput(
params=[sample_param, sample_param],
kwargs={},
desc="a param group cannot have duplicate parameters",
),
error_type=UserWarning,
error_regex=".*a parameter group with duplicate parameters.*",
),
ErrorOptimizerInput(
OptimizerInput(
params=[{"params": sample_param}, {"params": sample_param}],
kwargs={},
desc="duplicate parameters should not occur across param groups either",
),
error_type=ValueError,
error_regex="some parameters appear in more than one parameter group",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=torch.tensor([0.001, 0.001])),
desc="Tensor lr must be 1-element",
),
error_type=ValueError,
error_regex="Tensor lr must be 1-element",
),
ErrorOptimizerInput(
OptimizerInput(
params=[("weight", sample_param), sample_param2],
kwargs={},
desc="all optimizer params should be with/without names",
),
error_type=ValueError,
error_regex="all optimizer params should be with/without names. Some param names are missing",
),
ErrorOptimizerInput(
OptimizerInput(
params=[
{"params": [sample_param], "lr": 1e-2},
{"params": [("weight", sample_param2)]},
],
kwargs={},
desc="all optimizer param groups should be with/without names.",
),
error_type=ValueError,
error_regex="all optimizer param groups should be with/without names. "
"cannot add param group with names to the optimizer",
),
]
else:
return []
# ------------------------------------------------------------------------------------------
# NOTE: [optimizer kwarg categories]
# We categorize optimizer kwargs as 3 types:
# 1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
# algorithms and thus only show up for certain optimizers. There are many of these, so I
# do not bother gathering them all and listing them here. The converse to these would be
# global flags that every optimizer ideally _should_ support. We break global flags into
# 2 further categories and list them all below.
# 2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
# global-friendly flags are global flags who play nicely with all other global flags,
# i.e., are mutually exclusive in function. This means that any pair of the following
# flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the
# following flags theoretically can be enabled with ANY other global flag, including the
# cliquey ones (e.g, capturable and foreach).
# 3. global-cliquey = ["foreach", "fused", "differentiable"]
# global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
# usually because they contradict each other in function. For example, one should not flip
# both foreach AND fused to True, because they are two differing performance optimizations
# in which you can only opt into one.
#
# The following optim_inputs_func_* sampling functions only return constructor combinations of
# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
# well with additional kwargs. On the flip side of the same coin, we reserve setting the
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
def optim_inputs_func_adadelta(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "capturable": True},
desc="capturable with weight decay",
),
OptimizerInput(
params=None,
kwargs={"lr": torch.tensor(0.001), "capturable": True},
desc="Tensor lr with capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
),
OptimizerInput(
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_adadelta(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, rho=1.1),
desc="rho should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid rho value: 1.1",
),
]
return error_inputs
def optim_inputs_func_adafactor(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "lr": 0.01},
desc="nonzero weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"beta2_decay": -1.0},
desc="non-default beta2_decay",
),
OptimizerInput(
params=None,
kwargs={"d": 1.5},
desc="non-default clipping threshold d",
),
]
def optim_error_inputs_func_adafactor(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
complex_param.grad = torch.rand_like(complex_param)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(eps=(-1e-30, 1e-3)),
desc="epsilon1 should be >= 0",
),
error_type=ValueError,
error_regex="epsilon1 should be >= 0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(d=0.0),
desc="invalid d",
),
error_type=ValueError,
error_regex="Clipping threshold d should be >= 1",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(beta2_decay=0.8),
desc="invalid beta2_decay",
),
error_type=ValueError,
error_regex="beta2_decay should be <= 0",
),
ErrorOptimizerInput(
OptimizerInput(
params=[complex_param],
kwargs=dict(),
desc="does not support complex parameters",
),
error_type=RuntimeError,
error_regex="Adafactor does not support complex parameters",
error_on=OptimizerErrorEnum.STEP_ERROR,
),
]
return error_inputs
def optim_inputs_func_adagrad(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
OptimizerInput(
params=None,
kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
desc="initial_accumulator_value",
),
OptimizerInput(
params=None,
kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
desc="lr_decay",
), # TODO: Move out to testing in param_group?
OptimizerInput(
params=None,
kwargs={"lr": torch.tensor(0.001)},
desc="Tensor lr",
),
]
def optim_error_inputs_func_adagrad(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, lr_decay=-0.5),
desc="lr_decay must be bigger than 0",
),
error_type=ValueError,
error_regex="Invalid lr_decay value: -0.5",
),
]
return error_inputs
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
# with all implementation code paths...
def optim_inputs_func_adam(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
desc="capturable, amsgrad",
),
OptimizerInput(
params=None,
kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
desc="Tensor lr with capturable and amsgrad",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"betas": (torch.tensor(0.9), torch.tensor(0.99)),
"amsgrad": True,
"capturable": True,
},
desc="Tensor lr, Tensor betas, with capturable and amsgrad",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"betas": (torch.tensor(0.9), torch.tensor(0.99)),
"amsgrad": False,
"capturable": True,
},
desc="Tensor lr, Tensor betas, with capturable",
),
]
mps_supported_configs = [
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
),
]
total = (
[
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "amsgrad": True},
desc="amsgrad",
),
]
+ (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
+ (mps_supported_configs if _get_device_type(device) == "mps" else [])
)
if dtype in (torch.float16,):
for input in total:
"""
Too small eps will make denom to be zero for low precision dtype
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
For example,
>>> a
tensor([0.], dtype=torch.float16)
>>> a + 1e-8
tensor([0.], dtype=torch.float16)
"""
input.kwargs["eps"] = 0.1
return total
def optim_error_inputs_func_adam(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=torch.tensor(0.001), foreach=True),
desc="lr as Tensor doesn't work with foreach & not capturable",
),
error_type=ValueError,
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))),
desc="betas must be either both floats or both Tensors",
),
error_type=ValueError,
error_regex="betas must be either both floats or both Tensors",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)),
desc="betas must be either both floats or both Tensors",
),
error_type=ValueError,
error_regex="betas must be either both floats or both Tensors",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(
lr=1e-2,
betas=(torch.tensor(0.9), torch.tensor(0.99)),
foreach=True,
),
desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
),
error_type=ValueError,
error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
),
]
if _get_device_type(device) == "cuda":
sample_tensor = torch.empty((), device=device, dtype=dtype)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=[sample_tensor],
kwargs={"foreach": True, "fused": True},
desc="`fused` and `foreach` cannot be `True` together",
),
error_type=RuntimeError,
error_regex="`fused` and `foreach` cannot be `True` together",
),
ErrorOptimizerInput(
OptimizerInput(
params=[sample_tensor],
kwargs={"fused": True, "differentiable": True},
desc="`fused` does not support `differentiable`",
),
error_type=RuntimeError,
error_regex="`fused` does not support `differentiable`",
),
]
return error_inputs
def optim_inputs_func_adamax(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
desc="capturable, maximize, weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
desc="capturable, maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
desc="capturable, weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"weight_decay": 0.9,
"maximize": False,
"capturable": True,
},
desc="capturable, weight_decay, tensor LR",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, weight_decay",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_adamax(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
desc="beta2 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 1: 1.0",
),
]
return error_inputs
def optim_inputs_func_adamw(device, dtype=None):
return optim_inputs_func_adam(device, dtype)
def optim_error_inputs_func_adamw(device, dtype):
return optim_error_inputs_func_adam(device, dtype)
def optim_inputs_func_asgd(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"maximize": True, "capturable": True},
desc="maximize, capturable",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "capturable": True},
desc="weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
desc="maximize, weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"weight_decay": 0.1,
"maximize": True,
"capturable": True,
},
desc="maximize, weight_decay, capturable, tensor LR",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize, nonzero weight_decay",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_asgd(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-0.5),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -0.5",
),
]
return error_inputs
def optim_inputs_func_lbfgs(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
),
OptimizerInput(
params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
),
OptimizerInput(
params=None,
kwargs={"line_search_fn": "strong_wolfe"},
desc="strong_wolfe",
),
]
def optim_error_inputs_func_lbfgs(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
return error_inputs
def optim_inputs_func_muon(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.2},
desc="non-default weight_decay",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.8},
desc="non-default momentum",
),
OptimizerInput(
params=None,
kwargs={"ns_steps": 6},
desc="passing alternative ns_steps",
),
OptimizerInput(
params=None,
kwargs={
"ns_coefficients": (3.4, -4.7, 2.0),
},
desc="passing alternative ns_coefficients",
),
]
def optim_error_inputs_func_muon(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
complex_param.grad = torch.rand_like(complex_param)
non_2d_param = torch.rand(2, 3, 4, device=device, dtype=dtype)
non_2d_param.grad = torch.rand_like(non_2d_param)
param = torch.rand(2, 3, device=device, dtype=dtype)
param.grad = torch.rand_like(param)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=[non_2d_param],
kwargs=dict(),
desc="only support 2D parameters",
),
error_type=ValueError,
error_regex="Muon only supports 2D parameters",
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
),
ErrorOptimizerInput(
OptimizerInput(
params=[param],
kwargs={"adjust_lr_fn": "arbitrary"},
desc="only support `original` and `match_rms_adamw`",
),
error_type=ValueError,
error_regex="Adjust learning rate function arbitrary is not supported",
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
),
ErrorOptimizerInput(
OptimizerInput(
params=[complex_param],
kwargs=dict(),
desc="does not support complex parameters",
),
error_type=RuntimeError,
error_regex="Muon does not support complex parameters",
error_on=OptimizerErrorEnum.STEP_ERROR,
),
]
return error_inputs
def optim_inputs_func_nadam(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
desc="weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.9,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
"capturable": True,
},
desc="decoupled_weight_decay, capturable",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"weight_decay": 0.9,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
"capturable": True,
},
desc="decoupled_weight_decay, capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
OptimizerInput(
params=None,
kwargs={"momentum_decay": 6e-3},
desc="non-zero momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
},
desc="weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
desc="weight_decay, momentum_decay",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
"momentum_decay": 6e-3,
"decoupled_weight_decay": True,
},
desc="decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_nadam(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum_decay=-0.2),
desc="momentum_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid momentum_decay value: -0.2",
),
]
return error_inputs
# Weird story bro, NAdam and RAdam do not have maximize.
def optim_inputs_func_radam(device=None, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={
"capturable": True,
"weight_decay": 0.1,
},
desc="capturable, weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"capturable": True,
"weight_decay": 0.1,
"decoupled_weight_decay": True,
},
desc="capturable, weight_decay, decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={
"lr": torch.tensor(0.001),
"capturable": True,
"weight_decay": 0.1,
"decoupled_weight_decay": True,
},
desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
desc="decoupled_weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_radam(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, weight_decay=-1),
desc="weight_decay should > 0",
),
error_type=ValueError,
error_regex="Invalid weight_decay value: -1",
),
]
return error_inputs
def optim_inputs_func_rmsprop(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
desc="capturable, maximize",
),
OptimizerInput(
params=None,
kwargs={"lr": torch.tensor(0.001), "capturable": True},
desc="Tensor lr with capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True},
desc="centered",
),
OptimizerInput(
params=None,
kwargs={
"maximize": True,
"weight_decay": 0.1,
},
desc="maximize, weight_decay",
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
desc="momentum",
),
OptimizerInput(
params=None,
kwargs={
"weight_decay": 0.1,
"centered": True,
"momentum": 0.1,
"maximize": True,
},
desc="maximize, centered, weight_decay, w/ momentum",
),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_rmsprop(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-1.0),
desc="momentum should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid momentum value: -1.0",
),
]
return error_inputs
def optim_inputs_func_rprop(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
OptimizerInput(
params=None,
kwargs={"lr": torch.tensor(0.001), "capturable": True},
desc="Tensor lr with capturable",
),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
),
OptimizerInput(
params=None,
kwargs={"step_sizes": (2e-6, 100)},
desc="non-default step_sizes",
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
def optim_error_inputs_func_rprop(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
desc="0 < eta1 < 1 < eta2",
),
error_type=ValueError,
error_regex="Invalid eta values: 1.0, 0.5",
),
]
return error_inputs
def optim_inputs_func_sgd(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
),
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.1, "maximize": True},
desc="maximize",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "dampening": 0.5},
desc="dampening",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "weight_decay": 0.1},
desc="weight_decay w/ momentum",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
desc="nesterov",
),
]
def optim_error_inputs_func_sgd(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, momentum=-0.5),
desc="momentum should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid momentum value: -0.5",
),
]
return error_inputs
def optim_inputs_func_sparseadam(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(
params=None, kwargs={"lr": 0.01}, desc="non-default lr"
), # TODO: Move out to testing in param_group?
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
]
def optim_error_inputs_func_sparseadam(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
if _get_device_type(device) == "cpu":
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=None,
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
desc="beta1 should be between 0 and 1",
),
error_type=ValueError,
error_regex="Invalid beta parameter at index 0: 1.0",
),
ErrorOptimizerInput(
OptimizerInput(
params=[
torch.zeros(
3, layout=torch.sparse_coo, device=device, dtype=dtype
)
],
kwargs={},
desc="dense params required",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
ErrorOptimizerInput(
OptimizerInput(
params=[
{
"params": [
torch.zeros(
3,
layout=torch.sparse_coo,
device=device,
dtype=dtype,
)
]
}
],
kwargs={},
desc="dense params required in param_groups",
),
error_type=ValueError,
error_regex="SparseAdam requires dense parameter tensors",
),
ErrorOptimizerInput(
OptimizerInput(
params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
kwargs={},
desc="complex not supported",
),
error_type=ValueError,
error_regex="SparseAdam does not support complex parameters",
),
]
return error_inputs
def _get_device_type(device: Union[str, torch.device]) -> str:
# Returns the device type as a string, e.g., "cpu" or "cuda"
if isinstance(device, torch.device):
device = str(device.type)
assert isinstance(device, str)
return device.split(":")[0]
def _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=()
) -> list[OptimizerInput]:
"""
Return a list of all configs for a given optimizer as a list of OptimizerInputs,
including configs that have supported global cliquey kwargs (foreach, fused,
differentiable) based on optim_info.supported_impls.
The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
intentionally do NOT include global cliquey kwargs to give flexibility to tests.
For example, testing correctness between toggling foreach on and off is now
trivial. That said, we sometimes want to test for all possible configs on an
optimizer including all supported flags, so this helper returns all optim inputs.
"""
assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
"skip must be a subset of ['foreach', 'fused', 'differentiable']"
)
optim_inputs = optim_info.optim_inputs_func(device)
supported_impls = tuple(
x
for x in optim_info.supported_impls
if x not in skip
and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
and (
_get_device_type(device) in _get_foreach_kernels_supported_devices()
or x != "foreach"
)
)
all_optim_inputs = []
for optim_input in optim_inputs:
# Add the base config where all the flags are False
base_kwargs = deepcopy(optim_input.kwargs)
if len(supported_impls) != 0:
for flag in supported_impls:
base_kwargs[flag] = False
all_optim_inputs.append(
OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
)
else:
all_optim_inputs.append(optim_input)
# Add a config for when each of the global cliquey kwargs is True
# Note that in [optimizer kwarg categories], these kwargs are mutually
# exclusive, so we do not need to product them together.
for flag in supported_impls:
new_kwargs = deepcopy(base_kwargs)
new_kwargs[flag] = True
all_optim_inputs.append(
OptimizerInput(
params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
)
)
return all_optim_inputs
# Database of OptimizerInfo entries in alphabetical order.
optim_db: list[OptimizerInfo] = [
OptimizerInfo(
Adadelta,
optim_inputs_func=optim_inputs_func_adadelta,
optim_error_inputs_func=optim_error_inputs_func_adadelta,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
# Note on tolerances:
# test_correctness_Adadelta_cuda_float32
# Mismatched elements: 10 / 100 (10.0%)
# Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
# Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
# This is due to floating point ordering error + usage of sqrt
DecorateInfo(
toleranceOverride(
{
torch.float32: tol(
rtol=5.5e-4,
atol=5e-5,
)
}
),
"CompiledOptimizerParityTests",
"test_correctness",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
Adafactor,
optim_inputs_func=optim_inputs_func_adafactor,
optim_error_inputs_func=optim_error_inputs_func_adafactor,
supported_impls=("foreach",),
not_og_supported_flags=("foreach",),
supports_complex=False,
skips=(
DecorateInfo(
unittest.skip("See #133268 regarding dtype being None"),
"CompiledOptimizerParityTests",
"test_correctness",
device_type="cuda",
active_if=lambda kwargs: kwargs.get("use_closure", False),
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_can_load_older_state_dict",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_foreach_large_tensor",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_foreach_matches_forloop",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_load_nontensor_step",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_mixed_device_dtype",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_param_groups_lr",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_param_groups_weight_decay",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_peak_memory_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_save_load_equality_with_weights_only",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #116028 regarding copy not supported"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_state_dict_deterministic",
device_type="cuda",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
device_type="cuda",
),
DecorateInfo(
unittest.skip("See #133268 regarding dtype being None"),
"CompiledOptimizerParityTests",
"test_correctness",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_can_load_older_state_dict",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_load_nontensor_step",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_param_groups_lr",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_param_groups_weight_decay",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_save_load_equality_with_weights_only",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_state_dict_deterministic",
device_type="xpu",
),
DecorateInfo(
skipIfTorchDynamo("See #133268 regarding dtype being None"),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
device_type="xpu",
),
),
),
OptimizerInfo(
Adagrad,
optim_inputs_func=optim_inputs_func_adagrad,
optim_error_inputs_func=optim_error_inputs_func_adagrad,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu",),
supports_sparse=True,
metadata_for_sparse=(
{"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
[
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
],
),
decorators=(
DecorateInfo(
# Note on tolerances:
# difference comes from the fact that the non fused kernel have
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
# to make sure there is no discrepancies between cuda fused kernel
# and cpu fused kernel
toleranceOverride(
{
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
torch.float16: tol(atol=5e-3, rtol=5e-3),
}
),
"TestOptimRenewed",
"test_fused_matches_forloop",
),
),
skips=(
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
Adam,
optim_inputs_func=optim_inputs_func_adam,
scheduler_inputs=(
[lambda opt: ExponentialLR(opt, gamma=0.9)],
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
[
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9),
],
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
optim_error_inputs_func=optim_error_inputs_func_adam,
supported_impls=("foreach", "differentiable", "fused"),
has_capturable_arg=True,
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu", "cuda", "mps"),
decorators=(
# Expected floating point error between fused and compiled forloop
DecorateInfo(
toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
"TestOptimRenewed",
"test_fused_matches_forloop",
active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
and kwargs["dtype"] == torch.float64,
),
DecorateInfo(
# Note on tolerances:
# difference comes from the fact that the non fused kernel have
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
# to make sure there is no discrepancies between cuda fused kernel
# and cpu fused kernel
toleranceOverride(
{
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
torch.float16: tol(atol=5e-3, rtol=5e-3),
}
),
"TestOptimRenewed",
"test_fused_matches_forloop",
),
DecorateInfo(
# Note on tolerances:
# Tracking through #127000
toleranceOverride(
{
torch.float32: tol(atol=3e-5, rtol=1.3e-06),
}
),
"TestCudaOptims",
"test_grad_scaling_autocast_fused_optimizers",
),
),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
Adamax,
optim_inputs_func=optim_inputs_func_adamax,
optim_error_inputs_func=optim_error_inputs_func_adamax,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
unittest.skip("Uses too much memory, even for H100, surprisingly."),
"TestOptimRenewed",
"test_foreach_large_tensor",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
AdamW,
optim_inputs_func=optim_inputs_func_adamw,
optim_error_inputs_func=optim_error_inputs_func_adamw,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_fused_on=("cpu", "cuda", "mps"),
has_capturable_arg=True,
decorators=(
# Expected error between compiled forloop and fused optimizers
DecorateInfo(
toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
"TestOptimRenewed",
"test_fused_matches_forloop",
active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
and kwargs["dtype"] == torch.float64,
),
DecorateInfo(
toleranceOverride(
# Note on tolerances:
# difference comes from the fact that the non fused kernel have
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
# to make sure there is no discrepancies between cuda fused kernel
# and cpu fused kernel
{
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
torch.float16: tol(atol=5e-3, rtol=5e-3),
}
),
"TestOptimRenewed",
"test_fused_matches_forloop",
),
# Note on tolerances:
# Tracking through #127000
DecorateInfo(
toleranceOverride(
{
torch.float32: tol(
atol=3e-5,
rtol=1.3e-06,
)
}
),
"TestCudaOptims",
"test_grad_scaling_autocast_fused_optimizers",
),
),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
ASGD,
optim_inputs_func=optim_inputs_func_asgd,
optim_error_inputs_func=optim_error_inputs_func_asgd,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
toleranceOverride(
{
torch.float32: tol(atol=1.5e-5, rtol=1e-5),
}
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
DecorateInfo(
unittest.skip(
"ASGD internally changes the weights even with zero grad"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
),
),
OptimizerInfo(
LBFGS,
optim_inputs_func=optim_inputs_func_lbfgs,
optim_error_inputs_func=optim_error_inputs_func_lbfgs,
supported_impls=(),
step_requires_closure=True,
supports_param_groups=False,
supports_multiple_devices=False,
skips=(
# Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
DecorateInfo(
skipIfMPS,
"TestOptimRenewed",
"test_can_load_older_state_dict",
device_type="mps",
),
DecorateInfo(
toleranceOverride(
{
torch.complex64: tol(
rtol=4.5e-5,
atol=5e-5,
)
}
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
unittest.skip("Does not support param groups"),
"TestOptimRenewed",
"test_param_groups_lr",
),
DecorateInfo(
unittest.skip("Does not support param groups"),
"TestOptimRenewed",
"test_param_groups_weight_decay",
),
DecorateInfo(
unittest.skip("LBFGS doesn't support multidevice"),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
unittest.skip("Does not support param groups"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
# https://github.com/pytorch/pytorch/issues/131398
DecorateInfo(
unittest.expectedFailure,
"CompiledOptimizerParityTests",
"test_correctness",
active_if=lambda kwargs: sys.platform == "darwin"
and kwargs["use_closure"],
),
),
),
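# Muon (newly added in this PR): only the plain single-tensor implementation is
# provided, so there are no foreach/differentiable/fused variants, and complex
# parameters are not supported.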
OptimizerInfo(
Muon,
optim_inputs_func=optim_inputs_func_muon,
optim_error_inputs_func=optim_error_inputs_func_muon,
supported_impls=(),
not_og_supported_flags=(),
supports_complex=False,
skips=(
# Note on tolerances:
# test_correctness_Muon_use_closure_True_cuda_float32
# Mismatched elements: 2 / 100 (2.0%)
# Greatest absolute difference: 0.0006124898791313171 at index (2, 1) (up to 0.0002 allowed)
# Greatest relative difference: 0.026825083419680595 at index (2, 6) (up to 0.01 allowed)
# This is because compile uses addmm for matmul in the orthogonalization function,
# which creates a small numerical difference compared to the plain matmul op used in eager.
DecorateInfo(
toleranceOverride(
{
torch.float: tol(
rtol=0.08,
atol=0.001,
),
}
),
"CompiledOptimizerParityTests",
"test_correctness",
),
),
),
OptimizerInfo(
NAdam,
optim_inputs_func=optim_inputs_func_nadam,
optim_error_inputs_func=optim_error_inputs_func_nadam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
skipIfTorchDynamo(
"Errors, https://github.com/pytorch/pytorch/issues/117150"
),
"TestOptimRenewed",
"test_load_nontensor_step",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
RAdam,
optim_inputs_func=optim_inputs_func_radam,
optim_error_inputs_func=optim_error_inputs_func_radam,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
toleranceOverride(
{
# previously atol=1e-7, rtol=1e-7
torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
}
),
"TestOptimRenewed",
"test_foreach_matches_forloop",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
RMSprop,
optim_inputs_func=optim_inputs_func_rmsprop,
optim_error_inputs_func=optim_error_inputs_func_rmsprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
toleranceOverride(
{ # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
torch.float32: tol(atol=5e-04, rtol=0.01),
}
),
"TestOptimRenewed",
"test_mixed_device_dtype",
active_if=TEST_WITH_TORCHDYNAMO,
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
Rprop,
optim_inputs_func=optim_inputs_func_rprop,
optim_error_inputs_func=optim_error_inputs_func_rprop,
supported_impls=("foreach", "differentiable"),
has_capturable_arg=True,
skips=(
DecorateInfo(
skipIfTorchDynamo("See #116028"),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
SGD,
optim_inputs_func=optim_inputs_func_sgd,
scheduler_inputs=(
[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
),
optim_error_inputs_func=optim_error_inputs_func_sgd,
supported_impls=("foreach", "differentiable", "fused"),
not_og_supported_flags=(
"foreach",
"differentiable",
"fused",
"maximize",
"capturable",
),
supports_sparse=True,
metadata_for_sparse=(
{
"lr": 4.8e-3,
"maximize": False,
"momentum": 0,
"nesterov": False,
"weight_decay": 0,
},
[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
),
supports_fused_on=(
"cpu",
"cuda",
"mps",
),
skips=(
DecorateInfo(
skipIfTorchDynamo(
"Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
),
"TestOptimRenewed",
"test_set_default_dtype_works_with_foreach",
),
DecorateInfo(
skipIfTorchDynamo(
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
),
"TestOptimRenewed",
"test_complex_2d",
),
DecorateInfo(
toleranceOverride(
{ # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
torch.float32: tol(atol=5e-04, rtol=0.007),
}
),
"TestOptimRenewed",
"test_mixed_device_dtype",
active_if=TEST_WITH_TORCHDYNAMO,
),
DecorateInfo(
skipIfTorchDynamo(
"This test uses mocks, which dynamo does not support"
),
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
),
),
OptimizerInfo(
SparseAdam,
optim_inputs_func=optim_inputs_func_sparseadam,
optim_error_inputs_func=optim_error_inputs_func_sparseadam,
supported_impls=(),
only_supports_sparse_grads=True,
metadata_for_sparse=({"lr": 4e-2}, []),
supports_complex=False, # Missing complex support, see #118153
skips=(
DecorateInfo(
skipIfMPS, # SparseAdam does not support MPS
"TestOptimRenewed",
device_type="mps",
),
DecorateInfo(
skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_param_groups_lr",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_tensor_lr",
),
DecorateInfo(
unittest.skip(
"SparseAdam does not support dense gradients, see #116507"
),
"TestOptimRenewed",
"test_can_load_older_state_dict",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_load_nontensor_step",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_forloop_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_param_group_with_lrscheduler_goes_right_direction",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_state_dict_with_cuda_params",
),
DecorateInfo(
skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
),
),
),
]
class TensorTracker:
"""
A utility to track tensor clones in a list, with the expectation of popping them later (in
order) to make fair comparisons between two multi-step computations. The intended use case is
comparing two supposedly equal computations, such as optimizer steps that each consist of
multiple sub-steps, where numerical deviation could compound.
The goal is to compare and align values at every milestone so as to minimize numerical
discrepancies, so that when a test fails, it is likely a real problem.
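Illustrative sketch (the tensor names below are hypothetical and only for this example):

    tracker = TensorTracker()
    # record milestones from the reference computation, in order
    tracker.add(ref_intermediate)
    tracker.add(ref_result)
    # in the computation under test, check each milestone against the reference and
    # re-align values so that small errors do not compound across steps
    tracker.pop_check_set(test_intermediate, testcase)
    tracker.pop_check_set(test_result, testcase)
    assert tracker.all_popped()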
"""
def __init__(self, assert_eq_kwargs=None):
if assert_eq_kwargs is None:
assert_eq_kwargs = {}
self.assert_eq_kwargs = assert_eq_kwargs
self.tensors = []
def add(self, tensor):
"""
Add a detach().clone()'d version of the tensor
"""
self.tensors.append(tensor.detach().clone())
# pops from beginning, like a queue and not a stack!
def pop_check_set(self, tensor_to_set, testcase):
"""
Pop the first element in the tensor tracker, assert equality between the popped tensor and
the input tensor, and then set the input tensor to have the same values as the popped tensor
(with copy_).
"""
testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
ref = self.tensors.pop(0)
testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)
with torch.no_grad():
tensor_to_set.copy_(ref)
def all_popped(self):
return len(self.tensors) == 0