Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
A performance optimization: use `torch.addmm`, which fuses `matrix multiply + scale + add` into one op.

**Benchmark**

In training a QWEN-like 0.5B model, the average `optimizer.step()` latency dropped from ~44.5 ms (matmul) to ~27.4 ms (addmm), a **1.62×** speedup.

matmul:

![matmul optimizer.step() trace](https://github.com/user-attachments/assets/a77a68d4-da3c-473a-97f0-e6ef0a3b46d9)

addmm:

![addmm optimizer.step() trace](https://github.com/user-attachments/assets/e493af36-44d3-4026-9f7c-fd0f9cdbc7e5)

**Testing**

End-to-end training: we used a training script that pre-trains a QWEN-like model on the `openwebtext-100k` dataset. After training for one epoch, the loss curves for plain matmul and addmm are consistent.

![loss curves for matmul vs. addmm](https://github.com/user-attachments/assets/b96b13e3-0a01-4908-853c-d917b41f3d75)

Unit test:

```python
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()

lr = 1e-3
wd = 0.1
momentum = 0.95
nesterov = True  # value not shown in the original snippet

opt_ref_muon = Muon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
)
opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
    use_addmm=True,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```

The comparison shows a numeric difference, but this is expected at bf16 precision:

```
Mismatched elements: 96 / 100 (96.0%)
Greatest absolute difference: 8.985400199890137e-05 at index (1, 9) (up to 1e-06 allowed)
Greatest relative difference: 0.007370449136942625 at index (0, 6) (up to 1e-05 allowed)
```

~~Introduced a flag that allows users to opt in, as there are numerical differences relative to the original implementation.~~ Update: since `addmm` fuses the math ops, there are fewer intermediate roundings, so it is more numerically accurate than the original form. Based on this, we opt to make `addmm` the default and only option.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161379
Approved by: https://github.com/janeyx99
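For context, here is a minimal sketch of the fusion that `torch.addmm` provides (illustrative only, not the actual Muon kernel changes): `torch.addmm(input, mat1, mat2, beta=b, alpha=a)` computes `b * input + a * (mat1 @ mat2)` in a single call, replacing a separate matmul, scale, and add.

```python
# Minimal sketch of the addmm fusion; the scale factors are arbitrary illustrative values.
import torch

a, b = 3.4, -4.7                  # illustrative scale factors
X = torch.randn(64, 64)
A = X @ X.T

unfused = b * A + a * (A @ A)                    # matmul, scale, and add as separate ops
fused = torch.addmm(A, A, A, beta=b, alpha=a)    # one fused kernel: b*A + a*(A @ A)

# Results agree up to floating-point rounding from the different evaluation order.
torch.testing.assert_close(unfused, fused, rtol=1e-4, atol=1e-4)
```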
2310 lines · 84 KiB · Python
# mypy: ignore-errors
import functools
import itertools
import sys
import unittest
from copy import deepcopy
from enum import Enum
from typing import Any, Union

import torch
from torch import Tensor
from torch.nn import Parameter
from torch.optim import (
    Adadelta,
    Adafactor,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    LBFGS,
    Muon,
    NAdam,
    Optimizer,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
    SparseAdam,
)
from torch.optim.lr_scheduler import (
    ConstantLR,
    ExponentialLR,
    LinearLR,
    PolynomialLR,
    ReduceLROnPlateau,
    StepLR,
)
from torch.testing._internal.common_device_type import tol, toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import (
    _TestParametrizer,
    skipIfMPS,
    skipIfTorchDynamo,
    skipIfXpu,
    TEST_WITH_TORCHDYNAMO,
)
from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices

class OptimizerInput:
    """Contains args / kwargs to be passed to an optimizer constructor."""

    __slots__ = ["params", "kwargs", "desc"]

    def __init__(
        self,
        params: Union[
            list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
        ],
        kwargs: dict[str, Any],
        desc: str = "",
    ):
        # params can be a list of Tensors OR param_groups OR None
        self.params = params
        self.kwargs = kwargs
        self.desc = desc

    def __repr__(self):
        return f"params={self.params}, kwargs={self.kwargs}, desc={self.desc}"


class OptimizerErrorEnum(Enum):
    """Enumerates when an error is raised when testing optimizers."""

    CONSTRUCTION_ERROR = 0
    STEP_ERROR = 1


class ErrorOptimizerInput:
    """
    An OptimizerInput that will cause the optimizer to throw an error when constructed.
    Includes the type and string of the resulting error.
    """

    __slots__ = ["optimizer_error_input", "error_on", "error_type", "error_regex"]

    def __init__(
        self,
        optimizer_error_input,
        *,
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=RuntimeError,
        error_regex="",
    ):
        self.optimizer_error_input = optimizer_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex

class OptimizerInfo:
|
||
"""Optimizer information to be used in testing."""
|
||
|
||
def __init__(
|
||
self,
|
||
optim_cls: Optimizer, # Class object for the Optimizer under test
|
||
*,
|
||
# Function to generate optimizer inputs EXCLUDING params. We delegate params responsibility
|
||
# to the test using the OptimizerInfo. OptimizerInput.params is likely None.
|
||
# Can optionally take in device to filter out certain unsupported configs
|
||
optim_inputs_func,
|
||
# Tuple of lambdas to generate LRScheduler instances to run with the optimizer for the
|
||
# LRScheduler tests like test_forloop_goes_right_direction with_lrsched.
|
||
# We DO NOT expect to thoroughly test LRSchedulers through the optimizers, so not every
|
||
# LRScheduler configuration will be included. See test_lrscheduler.py for that instead.
|
||
# A few optimizers like SGD and Adam will test more LRSchedulers.
|
||
scheduler_inputs=(
|
||
[
|
||
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
|
||
lambda opt: ReduceLROnPlateau(opt),
|
||
],
|
||
),
|
||
# A subset of the global-cliquey flags (fused, foreach, differentiable) the optimizer
|
||
# supports. See NOTE: [optimizer kwarg categories] for what global-cliquey means.
|
||
supported_impls: tuple[str, ...] = ("foreach", "differentiable"),
|
||
# A subset of all flags, signifying which ones were only supported after the
|
||
# original optimizer had already been released. aka impls where we need to check BC.
|
||
not_og_supported_flags: tuple[str, ...] = (
|
||
"foreach",
|
||
"differentiable",
|
||
"maximize",
|
||
"capturable",
|
||
),
|
||
# the optim supports passing in sparse gradients as well as dense grads
|
||
supports_sparse: bool = False,
|
||
# the optimizer constructor supports passing in capturable as a kwarg
|
||
has_capturable_arg: bool = False,
|
||
# the optim only supports one config: sparse grads w/ dense params, see SparseAdam
|
||
only_supports_sparse_grads: bool = False,
|
||
# Tuple of (optimizer kwargs, schedulers_constructors) specifically for sparse tests,
|
||
# with especially tuned hyperparameters. These only apply if the optimizer supports
|
||
# sparse parameters or grads.
|
||
metadata_for_sparse=({}, []),
|
||
# the optim supports complex parameters
|
||
supports_complex: bool = True,
|
||
# whether the optimizer.step() function requires a closure to be passed
|
||
step_requires_closure: bool = False,
|
||
# whether the optimizer supports per-param options with parameter groups
|
||
supports_param_groups: bool = True,
|
||
# whether the optimizer supports parameters on multiple devices
|
||
supports_multiple_devices: bool = True,
|
||
skips=(), # Indicates which tests to skip
|
||
decorators=None, # Additional decorators to apply to generated tests
|
||
optim_error_inputs_func=None, # Function to generate optim inputs that error
|
||
supports_fused_on: tuple[str, ...] = (),
|
||
):
|
||
self.optim_cls = optim_cls
|
||
self.optim_inputs_func = optim_inputs_func
|
||
self.scheduler_inputs = scheduler_inputs
|
||
self.supported_impls = supported_impls
|
||
self.not_og_supported_flags = not_og_supported_flags
|
||
self.supports_sparse = supports_sparse
|
||
self.has_capturable_arg = has_capturable_arg
|
||
self.metadata_for_sparse = metadata_for_sparse
|
||
self.only_supports_sparse_grads = only_supports_sparse_grads
|
||
self.supports_complex = supports_complex
|
||
self.step_requires_closure = step_requires_closure
|
||
self.supports_param_groups = supports_param_groups
|
||
self.supports_multiple_devices = supports_multiple_devices
|
||
self.decorators = (
|
||
*(decorators if decorators else []),
|
||
*(skips if skips else []),
|
||
)
|
||
self.optim_error_inputs_func = optim_error_inputs_func
|
||
self.supports_fused_on = supports_fused_on
|
||
|
||
def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
|
||
result = []
|
||
for decorator in self.decorators:
|
||
if isinstance(decorator, DecorateInfo):
|
||
if decorator.is_active(
|
||
test_class, test_name, device, dtype, param_kwargs
|
||
):
|
||
result.extend(decorator.decorators)
|
||
else:
|
||
result.append(decorator)
|
||
return result
|
||
|
||
@property
|
||
def name(self):
|
||
return self.optim_cls.__name__
|
||
|
||
|
||
class optims(_TestParametrizer):
|
||
"""Decorator for specifying a list of optimizers over which to run a test."""
|
||
|
||
def __init__(self, optim_info_iterable, dtypes=None):
|
||
self.optim_info_list = list(optim_info_iterable)
|
||
|
||
        # optimizers aren't limited to a single dtype, as parameters can have different dtypes
|
||
# We default to torch.float32, but dtypes should be specified through passed in
|
||
# parameters.
|
||
self.dtypes = dtypes if dtypes is not None else [torch.float32]
|
||
|
||
def _parametrize_test(self, test, generic_cls, device_cls):
|
||
if device_cls is None:
|
||
raise RuntimeError(
|
||
"The @optims decorator is only intended to be used in a device-specific "
|
||
"context; use it with instantiate_device_type_tests() instead of "
|
||
"instantiate_parametrized_tests()"
|
||
)
|
||
|
||
for optim_info, dtype in itertools.product(self.optim_info_list, self.dtypes):
|
||
# Construct the test name; device / dtype parts are handled outside.
|
||
# See [Note: device and dtype suffix placement]
|
||
test_name = optim_info.name
|
||
|
||
# Construct parameter kwargs to pass to the test.
|
||
param_kwargs = {"optim_info": optim_info, "dtype": dtype}
|
||
|
||
try:
|
||
|
||
@functools.wraps(test)
|
||
def test_wrapper(*args, **kwargs):
|
||
return test(*args, **kwargs)
|
||
|
||
decorator_fn = functools.partial(
|
||
optim_info.get_decorators,
|
||
generic_cls.__name__,
|
||
test.__name__,
|
||
device_cls.device_type,
|
||
dtype,
|
||
)
|
||
|
||
yield (test_wrapper, test_name, param_kwargs, decorator_fn)
|
||
except Exception as ex:
|
||
# Provides an error message for debugging before rethrowing the exception
|
||
print(
|
||
f"Failed to instantiate {test_name} for module {optim_info.name}!"
|
||
)
|
||
raise ex
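# Illustrative usage of the @optims decorator (assumed, mirroring how device-specific
# optimizer tests consume it; not part of this module):
#
#     @optims(optim_db, dtypes=[torch.float32])
#     def test_step_runs(self, device, dtype, optim_info):
#         params = [Parameter(torch.randn(2, 2, device=device, dtype=dtype))]
#         optim_input = optim_info.optim_inputs_func(device)[0]
#         optim = optim_info.optim_cls(params, **optim_input.kwargs)
#         ...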
|
||
|
||
|
||
# Helper function for generating error inputs for all optimizers, used below.
|
||
def get_error_inputs_for_all_optims(device, dtype):
|
||
if _get_device_type(device) == "cpu":
|
||
# Creating 2D parameters for compatibility with Muon.
|
||
sample_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
|
||
sample_param2 = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
|
||
return [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=sample_param,
|
||
kwargs={},
|
||
desc="invalid param type",
|
||
),
|
||
error_type=TypeError,
|
||
error_regex="params argument given to the optimizer should be an iterable of Tensors or dicts",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[sample_param, sample_param],
|
||
kwargs={},
|
||
desc="a param group cannot have duplicate parameters",
|
||
),
|
||
error_type=UserWarning,
|
||
error_regex=".*a parameter group with duplicate parameters.*",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[{"params": sample_param}, {"params": sample_param}],
|
||
kwargs={},
|
||
desc="duplicate parameters should not occur across param groups either",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="some parameters appear in more than one parameter group",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=torch.tensor([0.001, 0.001])),
|
||
desc="Tensor lr must be 1-element",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Tensor lr must be 1-element",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[("weight", sample_param), sample_param2],
|
||
kwargs={},
|
||
desc="all optimizer params should be with/without names",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="all optimizer params should be with/without names. Some param names are missing",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[
|
||
{"params": [sample_param], "lr": 1e-2},
|
||
{"params": [("weight", sample_param2)]},
|
||
],
|
||
kwargs={},
|
||
desc="all optimizer param groups should be with/without names.",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="all optimizer param groups should be with/without names. "
|
||
"cannot add param group with names to the optimizer",
|
||
),
|
||
]
|
||
else:
|
||
return []
|
||
|
||
|
||
# ------------------------------------------------------------------------------------------
|
||
# NOTE: [optimizer kwarg categories]
|
||
# We categorize optimizer kwargs as 3 types:
|
||
# 1. optimizer-specific flags are like amsgrad or rho or beta, flags that are specific to
|
||
# algorithms and thus only show up for certain optimizers. There are many of these, so I
|
||
# do not bother gathering them all and listing them here. The converse to these would be
|
||
# global flags that every optimizer ideally _should_ support. We break global flags into
|
||
# 2 further categories and list them all below.
|
||
# 2. global-friendly = ["lr", "weight_decay", "maximize", "capturable"]
|
||
# global-friendly flags are global flags who play nicely with all other global flags,
|
||
# i.e., are mutually exclusive in function. This means that any pair of the following
|
||
# flags can be toggled at once (e.g., maximize and weight_decay). Furthermore, any of the
|
||
# following flags theoretically can be enabled with ANY other global flag, including the
|
||
# cliquey ones (e.g, capturable and foreach).
|
||
# 3. global-cliquey = ["foreach", "fused", "differentiable"]
|
||
# global-cliquey flags are global flags that do NOT coexist with other cliquey flags,
|
||
# usually because they contradict each other in function. For example, one should not flip
|
||
# both foreach AND fused to True, because they are two differing performance optimizations
|
||
# in which you can only opt into one.
|
||
#
|
||
# The following optim_inputs_func_* sampling functions only return constructor combinations of
|
||
# optimizer-specific and global-friendly flags. This is because we are confident they would mesh
|
||
# well with additional kwargs. On the flip side of the same coin, we reserve setting the
|
||
# global-cliquey flags to individual tests and fully expect tests to edit OptimizerInput.kwargs.
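# Illustrative example (not part of this module): a test that wants to exercise a cliquey
# flag takes a returned OptimizerInput and toggles the flag itself, enabling at most one of
# foreach / fused / differentiable at a time:
#
#     optim_input = OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr")
#     optim_input.kwargs["foreach"] = True
#     optim = torch.optim.SGD(model.parameters(), **optim_input.kwargs)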
|
||
|
||
|
||
def optim_inputs_func_adadelta(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "capturable": True},
|
||
desc="capturable with weight decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": torch.tensor(0.001), "capturable": True},
|
||
desc="Tensor lr with capturable",
|
||
),
|
||
]
|
||
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize, weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_adadelta(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, rho=1.1),
|
||
desc="rho should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid rho value: 1.1",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_adafactor(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "lr": 0.01},
|
||
desc="nonzero weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"beta2_decay": -1.0},
|
||
desc="non-default beta2_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"d": 1.5},
|
||
desc="non-default clipping threshold d",
|
||
),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_adafactor(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
|
||
complex_param.grad = torch.rand_like(complex_param)
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(eps=(-1e-30, 1e-3)),
|
||
desc="epsilon1 should be >= 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="epsilon1 should be >= 0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(d=0.0),
|
||
desc="invalid d",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Clipping threshold d should be >= 1",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(beta2_decay=0.8),
|
||
desc="invalid beta2_decay",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="beta2_decay should be <= 0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[complex_param],
|
||
kwargs=dict(),
|
||
desc="does not support complex parameters",
|
||
),
|
||
error_type=RuntimeError,
|
||
error_regex="Adafactor does not support complex parameters",
|
||
error_on=OptimizerErrorEnum.STEP_ERROR,
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_adagrad(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"initial_accumulator_value": 0.1, "weight_decay": 0.1},
|
||
desc="initial_accumulator_value",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": 0.1, "lr_decay": 0.5, "weight_decay": 0.1},
|
||
desc="lr_decay",
|
||
), # TODO: Move out to testing in param_group?
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": torch.tensor(0.001)},
|
||
desc="Tensor lr",
|
||
),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_adagrad(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, lr_decay=-0.5),
|
||
desc="lr_decay must be bigger than 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid lr_decay value: -0.5",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
# TODO: consider tensor LR! See multi_tensor_optimizer_configs in test_optim.py --> tensor LR should work
|
||
# with all implementation code paths...
|
||
def optim_inputs_func_adam(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "amsgrad": True, "capturable": True},
|
||
desc="capturable, amsgrad",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": torch.tensor(0.001), "amsgrad": True, "capturable": True},
|
||
desc="Tensor lr with capturable and amsgrad",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"betas": (torch.tensor(0.9), torch.tensor(0.99)),
|
||
"amsgrad": True,
|
||
"capturable": True,
|
||
},
|
||
desc="Tensor lr, Tensor betas, with capturable and amsgrad",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"betas": (torch.tensor(0.9), torch.tensor(0.99)),
|
||
"amsgrad": False,
|
||
"capturable": True,
|
||
},
|
||
desc="Tensor lr, Tensor betas, with capturable",
|
||
),
|
||
]
|
||
mps_supported_configs = [
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": torch.tensor(0.01)}, desc="Tensor lr"
|
||
),
|
||
]
|
||
|
||
total = (
|
||
[
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "amsgrad": True},
|
||
desc="amsgrad",
|
||
),
|
||
]
|
||
+ (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
+ (mps_supported_configs if _get_device_type(device) == "mps" else [])
|
||
)
|
||
if dtype in (torch.float16,):
|
||
for input in total:
|
||
"""
|
||
            A too-small eps will make the denom zero for low-precision dtypes
|
||
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
|
||
For example,
|
||
>>> a
|
||
tensor([0.], dtype=torch.float16)
|
||
>>> a + 1e-8
|
||
tensor([0.], dtype=torch.float16)
|
||
"""
|
||
input.kwargs["eps"] = 0.1
|
||
return total
|
||
|
||
|
||
def optim_error_inputs_func_adam(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
|
||
desc="beta1 should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid beta parameter at index 0: 1.0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, weight_decay=-1),
|
||
desc="weight_decay should > 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid weight_decay value: -1",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=torch.tensor(0.001), foreach=True),
|
||
desc="lr as Tensor doesn't work with foreach & not capturable",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="lr as a Tensor is not supported for capturable=False and foreach=True",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(0.9, torch.tensor(0.99))),
|
||
desc="betas must be either both floats or both Tensors",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="betas must be either both floats or both Tensors",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(torch.tensor(0.9), 0.99)),
|
||
desc="betas must be either both floats or both Tensors",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="betas must be either both floats or both Tensors",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(
|
||
lr=1e-2,
|
||
betas=(torch.tensor(0.9), torch.tensor(0.99)),
|
||
foreach=True,
|
||
),
|
||
desc=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex=r"betas\[0\] as a Tensor is not supported for capturable=False and foreach=True",
|
||
),
|
||
]
|
||
if _get_device_type(device) == "cuda":
|
||
sample_tensor = torch.empty((), device=device, dtype=dtype)
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[sample_tensor],
|
||
kwargs={"foreach": True, "fused": True},
|
||
desc="`fused` and `foreach` cannot be `True` together",
|
||
),
|
||
error_type=RuntimeError,
|
||
error_regex="`fused` and `foreach` cannot be `True` together",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[sample_tensor],
|
||
kwargs={"fused": True, "differentiable": True},
|
||
desc="`fused` does not support `differentiable`",
|
||
),
|
||
error_type=RuntimeError,
|
||
error_regex="`fused` does not support `differentiable`",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_adamax(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.9, "maximize": True, "capturable": True},
|
||
desc="capturable, maximize, weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0, "maximize": True, "capturable": True},
|
||
desc="capturable, maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.9, "maximize": False, "capturable": True},
|
||
desc="capturable, weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"weight_decay": 0.9,
|
||
"maximize": False,
|
||
"capturable": True,
|
||
},
|
||
desc="capturable, weight_decay, tensor LR",
|
||
),
|
||
]
|
||
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.1}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"maximize": True},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize, weight_decay",
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_adamax(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(0.0, 1.0)),
|
||
desc="beta2 should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid beta parameter at index 1: 1.0",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_adamw(device, dtype=None):
|
||
return optim_inputs_func_adam(device, dtype)
|
||
|
||
|
||
def optim_error_inputs_func_adamw(device, dtype):
|
||
return optim_error_inputs_func_adam(device, dtype)
|
||
|
||
|
||
def optim_inputs_func_asgd(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"maximize": True, "capturable": True},
|
||
desc="maximize, capturable",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "capturable": True},
|
||
desc="weight_decay, capturable",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
|
||
desc="maximize, weight_decay, capturable",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"weight_decay": 0.1,
|
||
"maximize": True,
|
||
"capturable": True,
|
||
},
|
||
desc="maximize, weight_decay, capturable, tensor LR",
|
||
),
|
||
]
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
|
||
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
|
||
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize, nonzero weight_decay",
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_asgd(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, weight_decay=-0.5),
|
||
desc="weight_decay should > 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid weight_decay value: -0.5",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_lbfgs(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
|
||
),
|
||
OptimizerInput(
|
||
params=None, kwargs={"tolerance_grad": 1e-6}, desc="tolerance_grad"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"line_search_fn": "strong_wolfe"},
|
||
desc="strong_wolfe",
|
||
),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_lbfgs(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_muon(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.2},
|
||
desc="non-default weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"momentum": 0.8},
|
||
desc="non-default momentum",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"ns_steps": 6},
|
||
desc="passing alternative ns_steps",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"ns_coefficients": (3.4, -4.7, 2.0),
|
||
},
|
||
desc="passing alternative ns_coefficients",
|
||
),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_muon(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
|
||
complex_param.grad = torch.rand_like(complex_param)
|
||
non_2d_param = torch.rand(2, 3, 4, device=device, dtype=dtype)
|
||
non_2d_param.grad = torch.rand_like(non_2d_param)
|
||
param = torch.rand(2, 3, device=device, dtype=dtype)
|
||
param.grad = torch.rand_like(param)
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[non_2d_param],
|
||
kwargs=dict(),
|
||
desc="only support 2D parameters",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Muon only supports 2D parameters",
|
||
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[param],
|
||
kwargs={"adjust_lr_fn": "arbitrary"},
|
||
desc="only support `original` and `match_rms_adamw`",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Adjust learning rate function arbitrary is not supported",
|
||
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[complex_param],
|
||
kwargs=dict(),
|
||
desc="does not support complex parameters",
|
||
),
|
||
error_type=RuntimeError,
|
||
error_regex="Muon does not support complex parameters",
|
||
error_on=OptimizerErrorEnum.STEP_ERROR,
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_nadam(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
|
||
desc="weight_decay, capturable",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"weight_decay": 0.9,
|
||
"momentum_decay": 6e-3,
|
||
"decoupled_weight_decay": True,
|
||
"capturable": True,
|
||
},
|
||
desc="decoupled_weight_decay, capturable",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"weight_decay": 0.9,
|
||
"momentum_decay": 6e-3,
|
||
"decoupled_weight_decay": True,
|
||
"capturable": True,
|
||
},
|
||
desc="decoupled_weight_decay, capturable",
|
||
),
|
||
]
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"momentum_decay": 6e-3},
|
||
desc="non-zero momentum_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"weight_decay": 0.1,
|
||
},
|
||
desc="weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "momentum_decay": 6e-3},
|
||
desc="weight_decay, momentum_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"weight_decay": 0.1,
|
||
"momentum_decay": 6e-3,
|
||
"decoupled_weight_decay": True,
|
||
},
|
||
desc="decoupled_weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_nadam(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
|
||
desc="beta1 should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid beta parameter at index 0: 1.0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, momentum_decay=-0.2),
|
||
desc="momentum_decay should > 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid momentum_decay value: -0.2",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
# Weird story bro, NAdam and RAdam do not have maximize.
|
||
def optim_inputs_func_radam(device=None, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"capturable": True,
|
||
"weight_decay": 0.1,
|
||
},
|
||
desc="capturable, weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"capturable": True,
|
||
"weight_decay": 0.1,
|
||
"decoupled_weight_decay": True,
|
||
},
|
||
desc="capturable, weight_decay, decoupled_weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"lr": torch.tensor(0.001),
|
||
"capturable": True,
|
||
"weight_decay": 0.1,
|
||
"decoupled_weight_decay": True,
|
||
},
|
||
desc="capturable, weight_decay, decoupled_weight_decay, tensor LR",
|
||
),
|
||
]
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 2e-3}, desc="non-default lr"),
|
||
OptimizerInput(params=None, kwargs={"eps": 1e-6}, desc="non-default eps"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "decoupled_weight_decay": True},
|
||
desc="decoupled_weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_radam(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
|
||
desc="beta1 should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid beta parameter at index 0: 1.0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, weight_decay=-1),
|
||
desc="weight_decay should > 0",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid weight_decay value: -1",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_rmsprop(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True, "capturable": True},
|
||
desc="capturable, maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": torch.tensor(0.001), "capturable": True},
|
||
desc="Tensor lr with capturable",
|
||
),
|
||
]
|
||
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 1e-3}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"maximize": True,
|
||
},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "centered": True},
|
||
desc="centered",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"maximize": True,
|
||
"weight_decay": 0.1,
|
||
},
|
||
desc="maximize, weight_decay",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "centered": True, "momentum": 0.1},
|
||
desc="momentum",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={
|
||
"weight_decay": 0.1,
|
||
"centered": True,
|
||
"momentum": 0.1,
|
||
"maximize": True,
|
||
},
|
||
desc="maximize, centered, weight_decay, w/ momentum",
|
||
),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_rmsprop(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, momentum=-1.0),
|
||
desc="momentum should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid momentum value: -1.0",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_rprop(device, dtype=None):
|
||
cuda_supported_configs = [
|
||
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"lr": torch.tensor(0.001), "capturable": True},
|
||
desc="Tensor lr with capturable",
|
||
),
|
||
]
|
||
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 2e-4}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"etas": (0.5, 1.5)}, desc="non-default etas"
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"step_sizes": (2e-6, 100)},
|
||
desc="non-default step_sizes",
|
||
),
|
||
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
|
||
] + (cuda_supported_configs if _get_device_type(device) == "cuda" else [])
|
||
|
||
|
||
def optim_error_inputs_func_rprop(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, etas=(1.0, 0.5)),
|
||
desc="0 < eta1 < 1 < eta2",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid eta values: 1.0, 0.5",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_sgd(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="non-default lr"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="tensor lr"
|
||
),
|
||
OptimizerInput(
|
||
params=None, kwargs={"weight_decay": 0.5}, desc="non-zero weight_decay"
|
||
),
|
||
OptimizerInput(params=None, kwargs={"momentum": 0.9}, desc="momentum"),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"weight_decay": 0.1, "maximize": True},
|
||
desc="maximize",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"momentum": 0.9, "dampening": 0.5},
|
||
desc="dampening",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"momentum": 0.9, "weight_decay": 0.1},
|
||
desc="weight_decay w/ momentum",
|
||
),
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs={"momentum": 0.9, "nesterov": True, "weight_decay": 0.1},
|
||
desc="nesterov",
|
||
),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_sgd(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, momentum=-0.5),
|
||
desc="momentum should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid momentum value: -0.5",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def optim_inputs_func_sparseadam(device, dtype=None):
|
||
return [
|
||
OptimizerInput(params=None, kwargs={}, desc="default"),
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": 0.01}, desc="non-default lr"
|
||
), # TODO: Move out to testing in param_group?
|
||
OptimizerInput(
|
||
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
|
||
),
|
||
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
|
||
]
|
||
|
||
|
||
def optim_error_inputs_func_sparseadam(device, dtype):
|
||
error_inputs = get_error_inputs_for_all_optims(device, dtype)
|
||
|
||
if _get_device_type(device) == "cpu":
|
||
error_inputs += [
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=None,
|
||
kwargs=dict(lr=1e-2, betas=(1.0, 0.0)),
|
||
desc="beta1 should be between 0 and 1",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="Invalid beta parameter at index 0: 1.0",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[
|
||
torch.zeros(
|
||
3, layout=torch.sparse_coo, device=device, dtype=dtype
|
||
)
|
||
],
|
||
kwargs={},
|
||
desc="dense params required",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="SparseAdam requires dense parameter tensors",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[
|
||
{
|
||
"params": [
|
||
torch.zeros(
|
||
3,
|
||
layout=torch.sparse_coo,
|
||
device=device,
|
||
dtype=dtype,
|
||
)
|
||
]
|
||
}
|
||
],
|
||
kwargs={},
|
||
desc="dense params required in param_groups",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="SparseAdam requires dense parameter tensors",
|
||
),
|
||
ErrorOptimizerInput(
|
||
OptimizerInput(
|
||
params=[torch.rand(2, 3, device=device, dtype=torch.complex64)],
|
||
kwargs={},
|
||
desc="complex not supported",
|
||
),
|
||
error_type=ValueError,
|
||
error_regex="SparseAdam does not support complex parameters",
|
||
),
|
||
]
|
||
return error_inputs
|
||
|
||
|
||
def _get_device_type(device: Union[str, torch.device]) -> str:
|
||
# Returns the device type as a string, e.g., "cpu" or "cuda"
|
||
if isinstance(device, torch.device):
|
||
device = str(device.type)
|
||
assert isinstance(device, str)
|
||
return device.split(":")[0]
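# e.g. _get_device_type("cuda:0") == "cuda" and _get_device_type(torch.device("cpu")) == "cpu"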
|
||
|
||
|
||
def _get_optim_inputs_including_global_cliquey_kwargs(
|
||
device, dtype, optim_info, skip=()
|
||
) -> list[OptimizerInput]:
|
||
"""
|
||
Return a list of all configs for a given optimizer as a list of OptimizerInputs,
|
||
including configs that have supported global cliquey kwargs (foreach, fused,
|
||
differentiable) based on optim_info.supported_impls.
|
||
|
||
The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
|
||
intentionally do NOT include global cliquey kwargs to give flexibility to tests.
|
||
For example, testing correctness between toggling foreach on and off is now
|
||
trivial. That said, we sometimes want to test for all possible configs on an
|
||
optimizer including all supported flags, so this helper returns all optim inputs.
|
||
"""
|
||
assert all(x in ["foreach", "fused", "differentiable"] for x in skip), (
|
||
"skip must be a subset of ['foreach', 'fused', 'differentiable']"
|
||
)
|
||
|
||
optim_inputs = optim_info.optim_inputs_func(device)
|
||
|
||
supported_impls = tuple(
|
||
x
|
||
for x in optim_info.supported_impls
|
||
if x not in skip
|
||
and (_get_device_type(device) in optim_info.supports_fused_on or x != "fused")
|
||
and (
|
||
_get_device_type(device) in _get_foreach_kernels_supported_devices()
|
||
or x != "foreach"
|
||
)
|
||
)
|
||
|
||
all_optim_inputs = []
|
||
for optim_input in optim_inputs:
|
||
# Add the base config where all the flags are False
|
||
base_kwargs = deepcopy(optim_input.kwargs)
|
||
if len(supported_impls) != 0:
|
||
for flag in supported_impls:
|
||
base_kwargs[flag] = False
|
||
all_optim_inputs.append(
|
||
OptimizerInput(params=None, kwargs=base_kwargs, desc=optim_input.desc)
|
||
)
|
||
else:
|
||
all_optim_inputs.append(optim_input)
|
||
# Add a config for when each of the global cliquey kwargs is True
|
||
# Note that in [optimizer kwarg categories], these kwargs are mutually
|
||
# exclusive, so we do not need to product them together.
|
||
for flag in supported_impls:
|
||
new_kwargs = deepcopy(base_kwargs)
|
||
new_kwargs[flag] = True
|
||
all_optim_inputs.append(
|
||
OptimizerInput(
|
||
params=None, kwargs=new_kwargs, desc=f"{optim_input.desc} & {flag}"
|
||
)
|
||
)
|
||
return all_optim_inputs
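# Illustrative expansion (not executed here): for an optimizer whose supported_impls are
# ("foreach", "differentiable"), a base config with kwargs={"lr": 0.01} yields three
# OptimizerInputs:
#
#     {"lr": 0.01, "foreach": False, "differentiable": False}  # base, all flags off
#     {"lr": 0.01, "foreach": True, "differentiable": False}   # desc "<base desc> & foreach"
#     {"lr": 0.01, "foreach": False, "differentiable": True}   # desc "<base desc> & differentiable"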
|
||
|
||
|
||
# Database of OptimizerInfo entries in alphabetical order.
|
||
optim_db: list[OptimizerInfo] = [
|
||
OptimizerInfo(
|
||
Adadelta,
|
||
optim_inputs_func=optim_inputs_func_adadelta,
|
||
optim_error_inputs_func=optim_error_inputs_func_adadelta,
|
||
supported_impls=("foreach", "differentiable"),
|
||
has_capturable_arg=True,
|
||
skips=(
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #116028"),
|
||
"TestOptimRenewed",
|
||
"test_set_default_dtype_works_with_foreach",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo(
|
||
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
|
||
),
|
||
"TestOptimRenewed",
|
||
"test_complex_2d",
|
||
),
|
||
# Note on tolerances:
|
||
# test_correctness_Adadelta_cuda_float32
|
||
# Mismatched elements: 10 / 100 (10.0%)
|
||
# Greatest absolute difference: 4.838220775127411e-05 at index (7, 4) (up to 1e-05 allowed)
|
||
# Greatest relative difference: 0.007270356640219688 at index (7, 2) (up to 1e-05 allowed)
|
||
# This is due to floating point ordering error + usage of sqrt
|
||
DecorateInfo(
|
||
toleranceOverride(
|
||
{
|
||
torch.float32: tol(
|
||
rtol=5.5e-4,
|
||
atol=5e-5,
|
||
)
|
||
}
|
||
),
|
||
"CompiledOptimizerParityTests",
|
||
"test_correctness",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo(
|
||
"This test uses mocks, which dynamo does not support"
|
||
),
|
||
"TestOptimRenewed",
|
||
"test_defaults_changed_to_foreach",
|
||
),
|
||
),
|
||
),
|
||
OptimizerInfo(
|
||
Adafactor,
|
||
optim_inputs_func=optim_inputs_func_adafactor,
|
||
optim_error_inputs_func=optim_error_inputs_func_adafactor,
|
||
supported_impls=("foreach",),
|
||
not_og_supported_flags=("foreach",),
|
||
supports_complex=False,
|
||
skips=(
|
||
DecorateInfo(
|
||
unittest.skip("See #133268 regarding dtype being None"),
|
||
"CompiledOptimizerParityTests",
|
||
"test_correctness",
|
||
device_type="cuda",
|
||
active_if=lambda kwargs: kwargs.get("use_closure", False),
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_can_load_older_state_dict",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_deepcopy_copies_all_public_attrs",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_foreach_large_tensor",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_foreach_matches_forloop",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_load_nontensor_step",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_mixed_device_dtype",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_param_groups_lr",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_param_groups_weight_decay",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_peak_memory_foreach",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_save_load_equality_with_weights_only",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #116028 regarding copy not supported"),
|
||
"TestOptimRenewed",
|
||
"test_set_default_dtype_works_with_foreach",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_state_dict_deterministic",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_step_is_noop_for_zero_grads",
|
||
device_type="cuda",
|
||
),
|
||
DecorateInfo(
|
||
unittest.skip("See #133268 regarding dtype being None"),
|
||
"CompiledOptimizerParityTests",
|
||
"test_correctness",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_can_load_older_state_dict",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_deepcopy_copies_all_public_attrs",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_load_nontensor_step",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_param_groups_lr",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_param_groups_weight_decay",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_save_load_equality_with_weights_only",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_state_dict_deterministic",
|
||
device_type="xpu",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #133268 regarding dtype being None"),
|
||
"TestOptimRenewed",
|
||
"test_step_is_noop_for_zero_grads",
|
||
device_type="xpu",
|
||
),
|
||
),
|
||
),
|
||
OptimizerInfo(
|
||
Adagrad,
|
||
optim_inputs_func=optim_inputs_func_adagrad,
|
||
optim_error_inputs_func=optim_error_inputs_func_adagrad,
|
||
supported_impls=("foreach", "differentiable", "fused"),
|
||
not_og_supported_flags=(
|
||
"foreach",
|
||
"differentiable",
|
||
"fused",
|
||
"maximize",
|
||
"capturable",
|
||
),
|
||
supports_fused_on=("cpu",),
|
||
supports_sparse=True,
|
||
metadata_for_sparse=(
|
||
{"lr": 0.1, "weight_decay": 0, "lr_decay": 0},
|
||
[
|
||
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
|
||
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
|
||
],
|
||
),
|
||
decorators=(
|
||
DecorateInfo(
|
||
# Note on tolerances:
|
||
# difference comes from the fact that the non fused kernel have
|
||
# more dtype cast operations. We have another test test_fused_cpu_matches_cuda
|
||
# to make sure there is no discrepancies between cuda fused kernel
|
||
# and cpu fused kernel
|
||
toleranceOverride(
|
||
{
|
||
torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
|
||
torch.float16: tol(atol=5e-3, rtol=5e-3),
|
||
}
|
||
),
|
||
"TestOptimRenewed",
|
||
"test_fused_matches_forloop",
|
||
),
|
||
),
|
||
skips=(
|
||
DecorateInfo(
|
||
skipIfTorchDynamo("See #116028"),
|
||
"TestOptimRenewed",
|
||
"test_set_default_dtype_works_with_foreach",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo(
|
||
"Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
|
||
),
|
||
"TestOptimRenewed",
|
||
"test_complex_2d",
|
||
),
|
||
DecorateInfo(
|
||
skipIfTorchDynamo(
|
||
"This test uses mocks, which dynamo does not support"
|
||
),
|
||
"TestOptimRenewed",
|
||
"test_defaults_changed_to_foreach",
|
||
),
|
||
),
|
||
),
|
||
    OptimizerInfo(
        Adam,
        optim_inputs_func=optim_inputs_func_adam,
        scheduler_inputs=(
            [lambda opt: ExponentialLR(opt, gamma=0.9)],
            [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
            [
                lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
                lambda opt: ExponentialLR(opt, gamma=0.9),
            ],
            [
                lambda opt: ExponentialLR(opt, gamma=0.9),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        optim_error_inputs_func=optim_error_inputs_func_adam,
        supported_impls=("foreach", "differentiable", "fused"),
        has_capturable_arg=True,
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_fused_on=("cpu", "cuda", "mps"),
        decorators=(
            # Expected floating point error between fused and compiled forloop
            DecorateInfo(
                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
                and kwargs["dtype"] == torch.float64,
            ),
            DecorateInfo(
                # Note on tolerances:
                # The difference comes from the fact that the non-fused kernels have
                # more dtype cast operations. We have another test, test_fused_cpu_matches_cuda,
                # to make sure there are no discrepancies between the cuda fused kernel
                # and the cpu fused kernel.
                toleranceOverride(
                    {
                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
                        torch.float16: tol(atol=5e-3, rtol=5e-3),
                    }
                ),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
            ),
            DecorateInfo(
                # Note on tolerances:
                # Tracking through #127000
                toleranceOverride(
                    {
                        torch.float32: tol(atol=3e-5, rtol=1.3e-06),
                    }
                ),
                "TestCudaOptims",
                "test_grad_scaling_autocast_fused_optimizers",
            ),
        ),
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        Adamax,
        optim_inputs_func=optim_inputs_func_adamax,
        optim_error_inputs_func=optim_error_inputs_func_adamax,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("See #116028"),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                unittest.skip("Uses too much memory, even for H100, surprisingly."),
                "TestOptimRenewed",
                "test_foreach_large_tensor",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        AdamW,
        optim_inputs_func=optim_inputs_func_adamw,
        optim_error_inputs_func=optim_error_inputs_func_adamw,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_fused_on=("cpu", "cuda", "mps"),
        has_capturable_arg=True,
        decorators=(
            # Expected error between compiled forloop and fused optimizers
            DecorateInfo(
                toleranceOverride({torch.float64: tol(atol=4.5e-7, rtol=2.2e-6)}),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
                active_if=lambda kwargs: TEST_WITH_TORCHDYNAMO
                and kwargs["dtype"] == torch.float64,
            ),
            DecorateInfo(
                toleranceOverride(
                    # Note on tolerances:
                    # The difference comes from the fact that the non-fused kernels have
                    # more dtype cast operations. We have another test, test_fused_cpu_matches_cuda,
                    # to make sure there are no discrepancies between the cuda fused kernel
                    # and the cpu fused kernel.
                    {
                        torch.bfloat16: tol(atol=5e-3, rtol=5e-3),
                        torch.float16: tol(atol=5e-3, rtol=5e-3),
                    }
                ),
                "TestOptimRenewed",
                "test_fused_matches_forloop",
            ),
            # Note on tolerances:
            # Tracking through #127000
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(
                            atol=3e-5,
                            rtol=1.3e-06,
                        )
                    }
                ),
                "TestCudaOptims",
                "test_grad_scaling_autocast_fused_optimizers",
            ),
        ),
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        ASGD,
        optim_inputs_func=optim_inputs_func_asgd,
        optim_error_inputs_func=optim_error_inputs_func_asgd,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.float32: tol(atol=1.5e-5, rtol=1e-5),
                    }
                ),
                "TestOptimRenewed",
                "test_step_is_noop_for_zero_grads",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
            DecorateInfo(
                unittest.skip(
                    "ASGD internally changes the weights even with zero grad"
                ),
                "TestOptimRenewed",
                "test_step_is_noop_for_zero_grads",
            ),
        ),
    ),
    OptimizerInfo(
        LBFGS,
        optim_inputs_func=optim_inputs_func_lbfgs,
        optim_error_inputs_func=optim_error_inputs_func_lbfgs,
        supported_impls=(),
        step_requires_closure=True,
        supports_param_groups=False,
        supports_multiple_devices=False,
        skips=(
            # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
            DecorateInfo(
                skipIfMPS,
                "TestOptimRenewed",
                "test_can_load_older_state_dict",
                device_type="mps",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        torch.complex64: tol(
                            rtol=4.5e-5,
                            atol=5e-5,
                        )
                    }
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_groups_weight_decay",
            ),
            DecorateInfo(
                unittest.skip("LBFGS doesn't support multidevice"),
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                unittest.skip("Does not support param groups"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            # https://github.com/pytorch/pytorch/issues/131398
            DecorateInfo(
                unittest.expectedFailure,
                "CompiledOptimizerParityTests",
                "test_correctness",
                active_if=lambda kwargs: sys.platform == "darwin"
                and kwargs["use_closure"],
            ),
        ),
    ),
    OptimizerInfo(
        Muon,
        optim_inputs_func=optim_inputs_func_muon,
        optim_error_inputs_func=optim_error_inputs_func_muon,
        supported_impls=(),
        not_og_supported_flags=(),
        supports_complex=False,
        skips=(
            # Note on numerical differences: `compile` applies different matmul tuning,
            # which leads to deviations compared to eager mode. In the Newton–Schulz
            # iteration for orthogonalization, computations are done in bfloat16, further
            # amplifying these numerical differences.
            DecorateInfo(
                unittest.skip(
                    "Expect high difference between compiled and eager due to bfloat16 and iterative process."
                ),
                "CompiledOptimizerParityTests",
                "test_correctness",
            ),
        ),
    ),
    OptimizerInfo(
        NAdam,
        optim_inputs_func=optim_inputs_func_nadam,
        optim_error_inputs_func=optim_error_inputs_func_nadam,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors, https://github.com/pytorch/pytorch/issues/117150"
                ),
                "TestOptimRenewed",
                "test_load_nontensor_step",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        RAdam,
        optim_inputs_func=optim_inputs_func_radam,
        optim_error_inputs_func=optim_error_inputs_func_radam,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {
                        # previously atol=1e-7, rtol=1e-7
                        torch.float64: tol(atol=1.5e-7, rtol=1.1e-7)
                    }
                ),
                "TestOptimRenewed",
                "test_foreach_matches_forloop",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        RMSprop,
        optim_inputs_func=optim_inputs_func_rmsprop,
        optim_error_inputs_func=optim_error_inputs_func_rmsprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("See #116028"),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
                        torch.float32: tol(atol=5e-04, rtol=0.01),
                    }
                ),
                "TestOptimRenewed",
                "test_mixed_device_dtype",
                active_if=TEST_WITH_TORCHDYNAMO,
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        Rprop,
        optim_inputs_func=optim_inputs_func_rprop,
        optim_error_inputs_func=optim_error_inputs_func_rprop,
        supported_impls=("foreach", "differentiable"),
        has_capturable_arg=True,
        skips=(
            DecorateInfo(
                skipIfTorchDynamo("See #116028"),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        SGD,
        optim_inputs_func=optim_inputs_func_sgd,
        scheduler_inputs=(
            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
            [
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.8, total_iters=4
                )
            ],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: LinearLR(
                    opt, start_factor=0.4, end_factor=0.6, total_iters=4
                ),
            ],
            [
                lambda opt: StepLR(opt, gamma=0.99, step_size=10),
                lambda opt: ExponentialLR(opt, gamma=0.99),
                lambda opt: ReduceLROnPlateau(opt),
            ],
            [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
            [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
            [
                lambda opt: StepLR(opt, gamma=0.9, step_size=10),
                lambda opt: ReduceLROnPlateau(opt),
            ],
        ),
        optim_error_inputs_func=optim_error_inputs_func_sgd,
        supported_impls=("foreach", "differentiable", "fused"),
        not_og_supported_flags=(
            "foreach",
            "differentiable",
            "fused",
            "maximize",
            "capturable",
        ),
        supports_sparse=True,
        metadata_for_sparse=(
            {
                "lr": 4.8e-3,
                "maximize": False,
                "momentum": 0,
                "nesterov": False,
                "weight_decay": 0,
            },
            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
        ),
        supports_fused_on=(
            "cpu",
            "cuda",
            "mps",
        ),
        skips=(
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028"
                ),
                "TestOptimRenewed",
                "test_set_default_dtype_works_with_foreach",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Accessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184"
                ),
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                toleranceOverride(
                    {  # previously atol=5e-05, rtol=0.001, https://github.com/pytorch/pytorch/issues/116202
                        torch.float32: tol(atol=5e-04, rtol=0.007),
                    }
                ),
                "TestOptimRenewed",
                "test_mixed_device_dtype",
                active_if=TEST_WITH_TORCHDYNAMO,
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "This test uses mocks, which dynamo does not support"
                ),
                "TestOptimRenewed",
                "test_defaults_changed_to_foreach",
            ),
        ),
    ),
    OptimizerInfo(
        SparseAdam,
        optim_inputs_func=optim_inputs_func_sparseadam,
        optim_error_inputs_func=optim_error_inputs_func_sparseadam,
        supported_impls=(),
        only_supports_sparse_grads=True,
        metadata_for_sparse=({"lr": 4e-2}, []),
        supports_complex=False,  # Missing complex support, see #118153
        skips=(
            DecorateInfo(
                skipIfMPS,  # SparseAdam does not support MPS
                "TestOptimRenewed",
                device_type="mps",
            ),
            DecorateInfo(
                skipIfXpu(msg="SparseAdam is not yet supported on the XPU stack"),
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_groups_lr",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_tensor_lr",
            ),
            DecorateInfo(
                unittest.skip(
                    "SparseAdam does not support dense gradients, see #116507"
                ),
                "TestOptimRenewed",
                "test_can_load_older_state_dict",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_load_nontensor_step",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_forloop_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_forloop_goes_right_direction_multigpu",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_param_group_with_lrscheduler_goes_right_direction",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_state_dict_with_cuda_params",
            ),
            DecorateInfo(
                skipIfTorchDynamo("cannot call to_sparse on p.grad, see #117184"),
                "TestOptimRenewed",
                "test_deepcopy_copies_all_public_attrs",
            ),
        ),
    ),
]
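
# Sketch of how optim_db is typically consumed (illustrative comment only; `test_something`
# and `model` are placeholder names, not definitions in this file). Test suites such as
# TestOptimRenewed parametrize over these OptimizerInfo entries via the `optims` decorator,
# roughly like:
#
#     @optims(optim_db, dtypes=[torch.float32])
#     def test_something(self, device, dtype, optim_info):
#         for optim_input in optim_info.optim_inputs_func(device=device):
#             optimizer = optim_info.optim_cls(model.parameters(), **optim_input.kwargs)
#             ...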


class TensorTracker:
    """
    A utility to track tensor clones in a list, with the expectation of popping them later (in
    order) to make fair comparisons between two multi-step computations. The intended use case is
    usually when comparing two supposedly equal computations, such as optimizer steps that each
    individually consist of multiple steps, where numerical deviation could multiply.

    The goal is to be able to compare and align numbers at every milestone so as to minimize
    numerical discrepancies, and so when the test fails, it is likely a real problem.
    """

    def __init__(self, assert_eq_kwargs=None):
        if assert_eq_kwargs is None:
            assert_eq_kwargs = {}
        self.assert_eq_kwargs = assert_eq_kwargs
        self.tensors = []

    def add(self, tensor):
        """
        Add a detach().clone()'d version of the tensor
        """
        self.tensors.append(tensor.detach().clone())

    # pops from beginning, like a queue and not a stack!
    def pop_check_set(self, tensor_to_set, testcase):
        """
        Pop the first element in the tensor tracker, assert equality between the popped tensor and
        the input tensor, and then set the input tensor to have the same values as the popped tensor
        (with copy_).
        """
        testcase.assertGreater(len(self.tensors), 0, "no tensors to pop")
        ref = self.tensors.pop(0)

        testcase.assertTrue(isinstance(ref, Tensor), f"{type(ref)=}")
        testcase.assertEqual(tensor_to_set, ref, **self.assert_eq_kwargs)

        with torch.no_grad():
            tensor_to_set.copy_(ref)

    def all_popped(self):
        return len(self.tensors) == 0
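

# Illustrative usage sketch for TensorTracker (a hypothetical helper, not used by the test
# suite): the reference computation records a snapshot after every step, and the computation
# under test pops, compares against, and re-aligns to each snapshot so that numerical drift
# cannot accumulate across steps. `testcase` is assumed to be a torch TestCase (from
# torch.testing._internal.common_utils), whose assertEqual understands tensor comparisons.
def _example_tensor_tracker_usage(testcase):
    tracker = TensorTracker()
    ref = torch.zeros(4)
    exp = ref.clone()

    # Reference path: take three steps, snapshotting after each one.
    for _ in range(3):
        ref.add_(1.0)
        tracker.add(ref)

    # Path under test: after each step, check against the recorded snapshot and copy it
    # back into `exp` so later comparisons start from identical values.
    for _ in range(3):
        exp.add_(1.0)
        tracker.pop_check_set(exp, testcase)

    testcase.assertTrue(tracker.all_popped())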