Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-03 07:24:58 +08:00)
Compare commits: 2 commits, v2.4.0-rc4 ... mlazos/mai
| Author | SHA1 | Date |
|---|---|---|
| | 3541de4b42 | |
| | 32220b239b | |
torch/distributed/optim/__init__.py
@@ -24,6 +24,9 @@ from .functional_sgd import _FunctionalSGD
from .named_optimizer import _NamedOptimizer
from .utils import as_functional_optim

from warnings import warn
warn("TorchScript support for functional optimizers is "
     "deprecated and will be removed in a future PyTorch release. Consider using the torch.compile optimizer instead.")

# DistributedOptimizer imports torch.distributed.rpc names, so gate availability
# based on RPC being available.
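Editor's note: the new warning above points users at torch.compile instead of the TorchScript functional optimizers. As a minimal, hedged sketch of that suggestion (model, sizes, and hyperparameters are illustrative, not from this diff), a compiled training step that includes the optimizer update looks roughly like:

```python
# Sketch only: compiling a step that contains the optimizer update,
# the alternative the deprecation warning refers to.
import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

@torch.compile
def train_step(inp):
    loss = model(inp).sum()
    loss.backward()
    opt.step()       # optimizer step is captured by torch.compile
    opt.zero_grad()
    return loss

train_step(torch.randn(4, 8))
```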
torch/optim/adadelta.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -227,70 +228,6 @@ Adadelta.__doc__ = (
)


def adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    capturable: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.
    """

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # We still respect when the user inputs False for foreach.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adadelta
    else:
        func = _single_tensor_adadelta

    func(
        params,
        grads,
        square_avgs,
        acc_deltas,
        state_steps,
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )


def _single_tensor_adadelta(
    params: List[Tensor],
    grads: List[Tensor],
@@ -428,3 +365,68 @@ def _multi_tensor_adadelta(
            torch._foreach_add_(device_params, deltas)
        else:
            torch._foreach_add_(device_params, deltas, alpha=-lr)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    capturable: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
):
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.
    """

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # We still respect when the user inputs False for foreach.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adadelta
    else:
        func = _single_tensor_adadelta

    func(
        params,
        grads,
        square_avgs,
        acc_deltas,
        state_steps,
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )
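Editor's note: the functional entry point above (both the removed top-of-file definition and the re-added, decorated one at the bottom of the file) expects pre-materialized state lists, with `state_steps` holding singleton tensors as the RuntimeError demands. A hedged usage sketch, assuming the function is importable from `torch.optim.adadelta` as in this diff; tensor shapes and hyperparameters are made up:

```python
# Illustrative call of the functional Adadelta API (sketch, not from this diff).
import torch
from torch.optim.adadelta import adadelta

params = [torch.randn(3)]
grads = [torch.randn(3)]
square_avgs = [torch.zeros(3)]
acc_deltas = [torch.zeros(3)]
state_steps = [torch.tensor(0.0)]  # singleton step tensors

adadelta(
    params, grads, square_avgs, acc_deltas, state_steps,
    lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.0, maximize=False,
)
```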
torch/optim/adam.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _fused_doc,
@@ -314,94 +315,6 @@ Adam.__doc__ = (
)


def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _single_tensor_adam(
    params: List[Tensor],
    grads: List[Tensor],
@@ -767,3 +680,92 @@ def _fused_adam(
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
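Editor's note: the comment block in `adam` above describes the implementation-selection logic: when neither `foreach` nor `fused` is specified, the default leans toward foreach (never silently enabling fused), and the default does not flip on foreach when `lr` is a Tensor with `capturable=False`. From the user side these knobs are the constructor flags of `torch.optim.Adam`; a small hedged illustration (values are made up):

```python
# Hedged illustration of the foreach/fused/capturable knobs discussed above.
import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]

# Let PyTorch pick the implementation (defaults toward foreach where supported).
opt_default = torch.optim.Adam(params, lr=1e-3)

# Explicit user choices are respected when set.
opt_foreach = torch.optim.Adam(params, lr=1e-3, foreach=True)
# fused=True additionally requires a supported device for the params at step() time:
# opt_fused = torch.optim.Adam(params, lr=1e-3, fused=True)

# A Tensor lr is accepted; per the check above, the default does not flip on
# foreach when lr is a Tensor and capturable=False.
opt_tensor_lr = torch.optim.Adam(params, lr=torch.tensor(1e-3), capturable=False)
```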
torch/optim/adamax.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
@@ -215,69 +216,6 @@ Adamax.__doc__ = (
)


def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """

    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )


def _single_tensor_adamax(
    params: List[Tensor],
    grads: List[Tensor],
@@ -452,3 +390,67 @@ def _multi_tensor_adamax(
        torch._foreach_addcdiv_(
            grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
        )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_infs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.
    """

    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )
torch/optim/adamw.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _fused_doc,
@@ -315,92 +316,6 @@ AdamW.__doc__ = (
)


def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )


def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
@@ -760,3 +675,90 @@ def _fused_adamw(
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        has_complex=has_complex,
    )
torch/optim/asgd.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
@@ -187,63 +188,6 @@ ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
"""


def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
@@ -441,3 +385,61 @@ def _multi_tensor_asgd(

        torch._foreach_copy_(grouped_etas, new_etas)
        torch._foreach_copy_(grouped_mus, new_mus)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )
torch/optim/nadam.py
@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _get_scalar_dtype,
@@ -258,76 +259,6 @@ NAdam.__doc__ = (
)


def nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.
    """

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError(
            "API has changed, `mu_products` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_nadam
    else:
        func = _single_tensor_nadam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        momentum_decay=momentum_decay,
        decoupled_weight_decay=decoupled_weight_decay,
        eps=eps,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )


def _single_tensor_nadam(
    params: List[Tensor],
    grads: List[Tensor],
@@ -598,3 +529,74 @@ def _multi_tensor_nadam(
        torch._foreach_addcdiv_(
            grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg
        )


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_nadam)
def nadam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    mu_products: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    momentum_decay: float,
    eps: float,
):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.
    """

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError(
            "API has changed, `mu_products` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_nadam
    else:
        func = _single_tensor_nadam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        mu_products,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        momentum_decay=momentum_decay,
        decoupled_weight_decay=decoupled_weight_decay,
        eps=eps,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )
torch/optim/optimizer.py
@@ -118,6 +118,27 @@ def _dispatch_sqrt(
    return math.sqrt(x)


def _disable_dynamo_if_unsupported(single_tensor_fn=None):
    # workaround for torchscript BC
    # it requires all called functions to be in the
    # global environment at the site at which the
    # maybe_fallback closure is created
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        @functools.wraps(func)
        def maybe_fallback(self, *args, **kwargs):
            if is_compiling() and not kwargs.get("capturable", False):
                return torch._disable_dynamo(func(self, *args, **kwargs))
            else:
                return func(self, *args, **kwargs)

        return maybe_fallback

    return wrapper


# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
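Editor's note: the decorator above captures the pattern used throughout these files: each public functional entry point is redefined after its `_single_tensor_*` / `_multi_tensor_*` kernels and wrapped with `@_disable_dynamo_if_unsupported(...)`, so that under compilation an unsupported configuration (e.g. `capturable=False`) falls back to an eager call instead of being traced. A minimal standalone sketch of the idea follows; the names `_my_fallback_if` and `_single_tensor_step` are hypothetical, and `torch._disable_dynamo` / `torch.compiler.is_compiling` are assumed private or version-dependent helpers:

```python
# Sketch of the "register kernel in globals, fall back outside dynamo" pattern.
import functools
import torch


def _my_fallback_if(single_tensor_fn=None):  # hypothetical name
    # TorchScript needs called functions visible in globals() at the site
    # where the closure is created, hence the explicit registration.
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        disabled = torch._disable_dynamo(func)  # assumption: private helper, may change

        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            # Skip dynamo tracing for the unsupported capturable=False case.
            if torch.compiler.is_compiling() and not kwargs.get("capturable", False):
                return disabled(*args, **kwargs)
            return func(*args, **kwargs)

        return maybe_fallback

    return wrapper


def _single_tensor_step(params, grads, *, lr):  # hypothetical kernel
    for p, g in zip(params, grads):
        p.add_(g, alpha=-lr)


@_my_fallback_if(single_tensor_fn=_single_tensor_step)
def step(params, grads, capturable=False, *, lr):
    _single_tensor_step(params, grads, lr=lr)
```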
torch/optim/radam.py
@@ -7,6 +7,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _dispatch_sqrt,
    _foreach_doc,
    _get_scalar_dtype,
@@ -234,67 +235,6 @@ RAdam.__doc__ = (
)


def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        decoupled_weight_decay=decoupled_weight_decay,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )


def _single_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
@@ -559,3 +499,65 @@ def _multi_tensor_radam(

        # Here, buffer = sqrt(1 - beta2^t) * rect_step_size / (sqrt(v) + eps) + unrect_step_size
        torch._foreach_addcmul_(grouped_params, grouped_exp_avgs, buffer)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_radam)
def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    decoupled_weight_decay: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """

    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        decoupled_weight_decay=decoupled_weight_decay,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )
torch/optim/rmsprop.py
@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -250,73 +251,6 @@ RMSprop.__doc__ = (
)


def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.
    See :class:`~torch.optim.RMSProp` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )


def _single_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
@@ -497,3 +431,71 @@ def _multi_tensor_rmsprop(
            torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
        else:
            torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.
    See :class:`~torch.optim.RMSProp` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )
torch/optim/rprop.py
@@ -6,6 +6,7 @@ from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_scalar_dtype,
    _maximize_doc,
@@ -209,68 +210,6 @@ Rprop.__doc__ = (
)


def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    maximize: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        state_steps,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        capturable=capturable,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
    )


def _single_tensor_rprop(
    params: List[Tensor],
    grads: List[Tensor],
@@ -441,3 +380,66 @@ def _multi_tensor_rprop(
        # Logically, you may expect grouped_prevs to get updated to grouped_grads, but that's
        # basically already happened since we've been using grouped_prevs' memory to store
        # updated grouped_grads!


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rprop)
def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    maximize: bool = False,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        state_steps,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        capturable=capturable,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
    )
torch/testing/_internal/common_optimizers.py
@@ -1561,14 +1561,6 @@ optim_db: List[OptimizerInfo] = [
                "TestOptimRenewed",
                "test_complex_2d",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "See https://github.com/pytorch/pytorch/issues/116499"
                ),
                "TestOptimRenewed",
                "test_can_load_older_state_dict",
                device_type="cuda",
            ),
            DecorateInfo(
                skipIfTorchDynamo(
                    "Errors, https://github.com/pytorch/pytorch/issues/117150"