Compare commits

...

2 Commits

SHA1 Message Date
3541de4b42 Enable dynamo'd test for 116499
ghstack-source-id: 8583662aa16f6d27e292c03ccb3ca67f2a91441b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123469
2024-05-07 11:20:29 -07:00
32220b239b Fallback to eager if we're compiling and capturable=True with TorchScript BC

ghstack-source-id: 0cb2ec55bac86888b29a34ddd06a604befeb0cb6
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123619
2024-05-07 11:20:29 -07:00
12 changed files with 660 additions and 626 deletions
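
Most of the diff is one mechanical change repeated across the optimizer modules: each functional entry point (adadelta, adam, adamax, adamw, asgd, nadam, radam, rmsprop, rprop) moves below its _single_tensor_*/_multi_tensor_* helpers and gains the new @_disable_dynamo_if_unsupported decorator. The relocation is forced by evaluation order, since decorator arguments are evaluated at definition time. A minimal sketch with hypothetical names (demo and _single_tensor_demo are made up; only the private helper is real):

    from torch.optim.optimizer import _disable_dynamo_if_unsupported  # private helper added in this diff

    def _single_tensor_demo(params, *, lr):  # hypothetical impl, defined first
        pass

    # Referencing _single_tensor_demo here would raise NameError if it were
    # defined further down the file, hence the wholesale reordering below.
    @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_demo)
    def demo(params, capturable=False, *, lr):
        pass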

torch/distributed/optim/__init__.py

@@ -24,6 +24,9 @@ from .functional_sgd import _FunctionalSGD
from .named_optimizer import _NamedOptimizer
from .utils import as_functional_optim
from warnings import warn
warn("TorchScript support for functional optimizers is deprecated and will be "
     "removed in a future PyTorch release. Consider using the torch.compile optimizer instead.")
# DistributedOptimizer imports torch.distributed.rpc names, so gate availability
# based on RPC being available.
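
Downstream code that has audited its use of these TorchScript functional optimizers can suppress just this notice; a sketch (the filter must be installed before the module's first import, and the message pattern only needs to match the warning's prefix):

    import warnings

    warnings.filterwarnings(
        "ignore",
        message="TorchScript support for functional optimizers is deprecated",
    )

    import torch.distributed.optim  # noqa: E402  (imported after installing the filter)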

torch/optim/adadelta.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -227,70 +228,6 @@ Adadelta.__doc__ = (
)
def adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
capturable: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
has_complex: bool = False,
*,
lr: float,
rho: float,
eps: float,
weight_decay: float,
maximize: bool,
):
r"""Functional API that performs Adadelta algorithm computation.
See :class:`~torch.optim.Adadelta` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
# We still respect when the user inputs False for foreach.
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adadelta
else:
func = _single_tensor_adadelta
func(
params,
grads,
square_avgs,
acc_deltas,
state_steps,
lr=lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
def _single_tensor_adadelta(
params: List[Tensor],
grads: List[Tensor],
@@ -428,3 +365,68 @@ def _multi_tensor_adadelta(
torch._foreach_add_(device_params, deltas)
else:
torch._foreach_add_(device_params, deltas, alpha=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
acc_deltas: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
capturable: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
has_complex: bool = False,
*,
lr: float,
rho: float,
eps: float,
weight_decay: float,
maximize: bool,
):
r"""Functional API that performs Adadelta algorithm computation.
See :class:`~torch.optim.Adadelta` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
# We still respect when the user inputs False for foreach.
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adadelta
else:
func = _single_tensor_adadelta
func(
params,
grads,
square_avgs,
acc_deltas,
state_steps,
lr=lr,
rho=rho,
eps=eps,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
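
Only the definition site moved; the call shape of the functional API is unchanged. A small sketch exercising the relocated function directly on CPU tensors (hyperparameter values are illustrative):

    import torch
    from torch.optim.adadelta import adadelta

    param = torch.randn(4)
    grad = torch.randn(4)
    square_avg = torch.zeros(4)
    acc_delta = torch.zeros(4)
    step = torch.tensor(0.0)  # each state_steps entry must be a scalar ("singleton") tensor

    adadelta(
        [param], [grad], [square_avg], [acc_delta], [step],
        lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0, maximize=False,
    )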

torch/optim/adam.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _fused_doc,
@@ -314,94 +315,6 @@ Adam.__doc__ = (
)
def adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
):
r"""Functional API that performs Adam algorithm computation.
See :class:`~torch.optim.Adam` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if fused and not torch.jit.is_scripting():
func = _fused_adam
elif foreach and not torch.jit.is_scripting():
func = _multi_tensor_adam
else:
func = _single_tensor_adam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
)
def _single_tensor_adam(
params: List[Tensor],
grads: List[Tensor],
@@ -767,3 +680,92 @@ def _fused_adam(
torch._foreach_sub_(
device_state_steps, [device_found_inf] * len(device_state_steps)
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
):
r"""Functional API that performs Adam algorithm computation.
See :class:`~torch.optim.Adam` for details.
"""
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if fused and not torch.jit.is_scripting():
func = _fused_adam
elif foreach and not torch.jit.is_scripting():
func = _multi_tensor_adam
else:
func = _single_tensor_adam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
has_complex=has_complex,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
)
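
One user-visible subtlety in the guard above: with a Tensor lr and capturable=False, adam stays on the single-tensor path even where foreach would otherwise be picked. A sketch of both configurations (the capturable variant assumes a CUDA device):

    import torch

    params = [torch.nn.Parameter(torch.randn(4))]

    # Tensor lr with capturable=False: the guard forces foreach=False.
    opt = torch.optim.Adam(params, lr=torch.tensor(1e-3))

    # Tensor lr with capturable=True keeps the foreach path eligible (CUDA only):
    # cuda_params = [torch.nn.Parameter(torch.randn(4, device="cuda"))]
    # opt = torch.optim.Adam(cuda_params, lr=torch.tensor(1e-3, device="cuda"), capturable=True)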

torch/optim/adamax.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
@@ -215,69 +216,6 @@ Adamax.__doc__ = (
)
def adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
eps: float,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
):
r"""Functional API that performs adamax algorithm computation.
See :class:`~torch.optim.Adamax` for details.
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adamax
else:
func = _single_tensor_adamax
func(
params,
grads,
exp_avgs,
exp_infs,
state_steps,
eps=eps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
capturable=capturable,
)
def _single_tensor_adamax(
params: List[Tensor],
grads: List[Tensor],
@@ -452,3 +390,67 @@ def _multi_tensor_adamax(
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_infs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
eps: float,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
):
r"""Functional API that performs adamax algorithm computation.
See :class:`~torch.optim.Adamax` for details.
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adamax
else:
func = _single_tensor_adamax
func(
params,
grads,
exp_avgs,
exp_infs,
state_steps,
eps=eps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
capturable=capturable,
)
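
For reference, the defaulting helper called above returns a (fused, foreach) pair. A rough sketch of the decision it makes (not the real body, which lives in torch/optim/optimizer.py; eligibility details are simplified):

    import torch

    def _default_to_fused_or_foreach_sketch(params, differentiable, use_fused):
        # foreach/fused are only considered when every param is a plain
        # torch.Tensor on CUDA and the optimizer is not in differentiable mode.
        if differentiable:
            return False, False
        eligible = all(type(p) is torch.Tensor and p.is_cuda for p in params)
        return use_fused and eligible, (not use_fused) and eligible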

torch/optim/adamw.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _fused_doc,
@@ -315,92 +316,6 @@ AdamW.__doc__ = (
)
def adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if fused and not torch.jit.is_scripting():
func = _fused_adamw
elif foreach and not torch.jit.is_scripting():
func = _multi_tensor_adamw
else:
func = _single_tensor_adamw
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
has_complex=has_complex,
)
def _single_tensor_adamw(
params: List[Tensor],
grads: List[Tensor],
@@ -760,3 +675,90 @@ def _fused_adamw(
torch._foreach_sub_(
device_state_steps, [device_found_inf] * len(device_state_steps)
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
def adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
if fused is None and foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
if foreach and isinstance(lr, Tensor) and not capturable:
foreach = False
if fused is None:
fused = False
if foreach is None:
foreach = False
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if fused and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with fused optimizers")
if fused and not torch.jit.is_scripting():
func = _fused_adamw
elif foreach and not torch.jit.is_scripting():
func = _multi_tensor_adamw
else:
func = _single_tensor_adamw
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
has_complex=has_complex,
)
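
With the decorator in place, a compiled training step that uses the default capturable=False is expected to drop out of the dynamo graph and run the update eagerly rather than fail. A hedged sketch of that scenario:

    import torch

    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)  # capturable defaults to False

    model(torch.randn(2, 8)).sum().backward()

    @torch.compile
    def compiled_step():
        opt.step()  # expected to fall back to eager via maybe_fallback, not error

    compiled_step()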

torch/optim/asgd.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
@@ -187,63 +188,6 @@ ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
"""
def asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
):
r"""Functional API that performs asgd algorithm computation.
See :class:`~torch.optim.ASGD` for details.
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
func = _single_tensor_asgd
func(
params,
grads,
axs,
mus,
etas,
state_steps,
lambd=lambd,
lr=lr,
t0=t0,
alpha=alpha,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
def _single_tensor_asgd(
params: List[Tensor],
grads: List[Tensor],
@@ -441,3 +385,61 @@ def _multi_tensor_asgd(
torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
params: List[Tensor],
grads: List[Tensor],
axs: List[Tensor],
mus: List[Tensor],
etas: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lambd: float,
lr: float,
t0: float,
alpha: float,
weight_decay: float,
):
r"""Functional API that performs asgd algorithm computation.
See :class:`~torch.optim.ASGD` for details.
"""
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_asgd
else:
func = _single_tensor_asgd
func(
params,
grads,
axs,
mus,
etas,
state_steps,
lambd=lambd,
lr=lr,
t0=t0,
alpha=alpha,
weight_decay=weight_decay,
maximize=maximize,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)

torch/optim/nadam.py

@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _get_scalar_dtype,
@@ -258,76 +259,6 @@ NAdam.__doc__ = (
)
def nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
momentum_decay: float,
eps: float,
):
r"""Functional API that performs NAdam algorithm computation.
See :class:`~torch.optim.NAdam` for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if not all(isinstance(t, torch.Tensor) for t in mu_products):
raise RuntimeError(
"API has changed, `mu_products` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_nadam
else:
func = _single_tensor_nadam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
decoupled_weight_decay=decoupled_weight_decay,
eps=eps,
capturable=capturable,
differentiable=differentiable,
has_complex=has_complex,
)
def _single_tensor_nadam(
params: List[Tensor],
grads: List[Tensor],
@@ -598,3 +529,74 @@ def _multi_tensor_nadam(
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
mu_products: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
momentum_decay: float,
eps: float,
):
r"""Functional API that performs NAdam algorithm computation.
See :class:`~torch.optim.NAdam` for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if not all(isinstance(t, torch.Tensor) for t in mu_products):
raise RuntimeError(
"API has changed, `mu_products` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_nadam
else:
func = _single_tensor_nadam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
mu_products,
state_steps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
momentum_decay=momentum_decay,
decoupled_weight_decay=decoupled_weight_decay,
eps=eps,
capturable=capturable,
differentiable=differentiable,
has_complex=has_complex,
)
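
The two isinstance checks above are strict: state_steps and mu_products must be lists of scalar ("singleton") tensors, one per parameter, never plain Python numbers:

    import torch

    state_steps = [torch.tensor(0.0)]  # OK: scalar tensor
    mu_products = [torch.tensor(1.0)]  # OK
    # state_steps = [0]                # would raise the RuntimeError above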

torch/optim/optimizer.py

@@ -118,6 +118,27 @@ def _dispatch_sqrt(
    return math.sqrt(x)

def _disable_dynamo_if_unsupported(single_tensor_fn=None):
    # workaround for torchscript BC
    # it requires all called functions to be in the
    # global environment at the site at which the
    # maybe_fallback closure is created
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            if is_compiling() and not kwargs.get("capturable", False):
                # Wrap func and then call it; calling func first and wrapping
                # the result would never actually disable dynamo.
                return torch._disable_dynamo(func)(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return maybe_fallback

    return wrapper
# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
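
At call time the wrapper is a plain either/or: compiling with capturable=False runs the dynamo-disabled function, anything else calls straight through. A toy model of that branch, with a stand-in for is_compiling (no torch required):

    import functools

    def fallback_if(pred):
        # Toy version of _disable_dynamo_if_unsupported's call-time logic;
        # pred() plays the role of is_compiling().
        def wrapper(func):
            @functools.wraps(func)
            def maybe_fallback(*args, **kwargs):
                if pred() and not kwargs.get("capturable", False):
                    return "eager", func(*args, **kwargs)  # the torch._disable_dynamo path
                return "captured", func(*args, **kwargs)
            return maybe_fallback
        return wrapper

    @fallback_if(lambda: True)  # pretend we are compiling
    def step(x, capturable=False):
        return x + 1

    assert step(1) == ("eager", 2)
    assert step(1, capturable=True) == ("captured", 2)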

torch/optim/radam.py

@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _get_scalar_dtype,
@@ -234,67 +235,6 @@ RAdam.__doc__ = (
)
def radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
):
r"""Functional API that performs RAdam algorithm computation.
See :class:`~torch.optim.RAdam` for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_radam
else:
func = _single_tensor_radam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
decoupled_weight_decay=decoupled_weight_decay,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)
def _single_tensor_radam(
params: List[Tensor],
grads: List[Tensor],
@@ -559,3 +499,65 @@ def _multi_tensor_radam(
# Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
decoupled_weight_decay: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
):
r"""Functional API that performs RAdam algorithm computation.
See :class:`~torch.optim.RAdam` for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_radam
else:
func = _single_tensor_radam
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
decoupled_weight_decay=decoupled_weight_decay,
differentiable=differentiable,
capturable=capturable,
has_complex=has_complex,
)

torch/optim/rmsprop.py

@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -250,73 +251,6 @@ RMSprop.__doc__ = (
)
def rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lr: float,
alpha: float,
eps: float,
weight_decay: float,
momentum: float,
centered: bool,
):
r"""Functional API that performs rmsprop algorithm computation.
See :class:`~torch.optim.RMSProp` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_rmsprop
else:
func = _single_tensor_rmsprop
func(
params,
grads,
square_avgs,
grad_avgs,
momentum_buffer_list,
state_steps,
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
has_complex=has_complex,
)
def _single_tensor_rmsprop(
params: List[Tensor],
grads: List[Tensor],
@@ -497,3 +431,71 @@ def _multi_tensor_rmsprop(
torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
else:
torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
params: List[Tensor],
grads: List[Tensor],
square_avgs: List[Tensor],
grad_avgs: List[Tensor],
momentum_buffer_list: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
maximize: bool = False,
differentiable: bool = False,
capturable: bool = False,
has_complex: bool = False,
*,
lr: float,
alpha: float,
eps: float,
weight_decay: float,
momentum: float,
centered: bool,
):
r"""Functional API that performs rmsprop algorithm computation.
See :class:`~torch.optim.RMSProp` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_rmsprop
else:
func = _single_tensor_rmsprop
func(
params,
grads,
square_avgs,
grad_avgs,
momentum_buffer_list,
state_steps,
lr=lr,
alpha=alpha,
eps=eps,
weight_decay=weight_decay,
momentum=momentum,
centered=centered,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
has_complex=has_complex,
)

torch/optim/rprop.py

@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -209,68 +210,6 @@ Rprop.__doc__ = (
)
def rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
maximize: bool = False,
differentiable: bool = False,
has_complex: bool = False,
*,
step_size_min: float,
step_size_max: float,
etaminus: float,
etaplus: float,
):
r"""Functional API that performs rprop algorithm computation.
See :class:`~torch.optim.Rprop` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_rprop
else:
func = _single_tensor_rprop
func(
params,
grads,
prevs,
step_sizes,
state_steps,
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
etaplus=etaplus,
capturable=capturable,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
)
def _single_tensor_rprop(
params: List[Tensor],
grads: List[Tensor],
@@ -441,3 +380,66 @@ def _multi_tensor_rprop(
# Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
# basically already happened since we've been using grouped_prevs' memory to store
# updated grouped_grads!
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
params: List[Tensor],
grads: List[Tensor],
prevs: List[Tensor],
step_sizes: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
maximize: bool = False,
differentiable: bool = False,
has_complex: bool = False,
*,
step_size_min: float,
step_size_max: float,
etaminus: float,
etaplus: float,
):
r"""Functional API that performs rprop algorithm computation.
See :class:`~torch.optim.Rprop` for details.
"""
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if foreach is None:
_, foreach = _default_to_fused_or_foreach(
params, differentiable, use_fused=False
)
if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_rprop
else:
func = _single_tensor_rprop
func(
params,
grads,
prevs,
step_sizes,
state_steps,
step_size_min=step_size_min,
step_size_max=step_size_max,
etaminus=etaminus,
etaplus=etaplus,
capturable=capturable,
maximize=maximize,
differentiable=differentiable,
has_complex=has_complex,
)

torch/testing/_internal/common_optimizers.py

@@ -1561,14 +1561,6 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/116499"
                ),
                "TestOptimRenewed",
                "test_can_load_older_state_dict",
                device_type="cuda",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors, https://github.com/pytorch/pytorch/issues/117150"