Make Adam, AdamW work with nonzero-dim Tensor betas (#149939)

Fixes #147921

## Changes

- Convert tensor `betas` using `_to_scalar`
- Change annotation of `betas` param
- Change param type in docs

## Test Result

```bash
pytest -s test/test_optim.py -k test_tensor_lr -vv
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149939
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Committed by PyTorch MergeBot · parent 48b54b45d6 · commit fdc8ccc5bc
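For context, a minimal usage sketch of what this change enables (illustrative only, not part of the PR; the toy parameter and hyperparameter values are arbitrary):

```python
import torch
from torch.optim import Adam

params = [torch.nn.Parameter(torch.randn(2, 2))]

# Nonzero-dim tensor betas are accepted as long as each has exactly one element;
# Adam.__init__ now converts them to 0-dim scalar tensors via _to_scalar. A
# multi-element tensor is rejected (e.g. "Tensor betas[1] must be 1-element").
opt = Adam(
    params,
    lr=1e-3,
    betas=(torch.tensor([0.9]), torch.tensor([[0.999]])),
)

# The stored betas should now be 0-dim tensors after conversion.
print(opt.param_groups[0]["betas"])
```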
@@ -84,6 +84,7 @@ class Adam(Optimizer):
                 )
             if betas[1].numel() != 1:
                 raise ValueError("Tensor betas[1] must be 1-element")
+        betas = tuple(map(_to_scalar, betas))

         defaults = {
             "lr": lr,
@@ -315,8 +316,9 @@ Adam.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
@@ -375,7 +377,8 @@ def _single_tensor_adam(
         assert isinstance(beta2, float)
     else:
         lr = _to_scalar(lr)
-        # TODO: Support nonzero-dim Tensor betas, see #147921
+        beta1 = _to_scalar(beta1)
+        beta2 = _to_scalar(beta2)

     # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
     # treating it as a scalar.
@@ -610,7 +613,8 @@ def _multi_tensor_adam(
     assert not differentiable, "_foreach ops don't support autograd"

     lr = _to_scalar(lr)
-    # TODO: Support nonzero-dim Tensor betas, see #147921
+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)

     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
         [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
@@ -801,8 +805,8 @@ def _fused_adam(
     *,
     amsgrad: bool,
     has_complex: bool,  # Needed for consistency.
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
@@ -816,6 +820,9 @@ def _fused_adam(
     if differentiable:
         raise RuntimeError("Adam with fused=True does not support differentiable=True")

+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)
+
     grad_scale_dict: DeviceDict = (
         {grad_scale.device: grad_scale} if grad_scale is not None else {}
     )
@@ -905,8 +912,8 @@ def adam(
     decoupled_weight_decay: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
@@ -102,8 +102,9 @@ AdamW.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay coefficient (default: 1e-2)
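The updated annotation above lets each beta independently be a float or a 1-element tensor. A small illustration with AdamW (a sketch under that assumption; values arbitrary):

```python
import torch
from torch.optim import AdamW

params = [torch.nn.Parameter(torch.randn(4))]

# Mixing a plain float with a nonzero-dim, 1-element tensor beta; both forms
# are normalized to scalars internally per the updated docstring.
AdamW(params, betas=(0.9, torch.tensor([[0.999]])))
```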
@@ -145,8 +146,8 @@ def adamw(
     has_complex: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
@@ -230,7 +230,7 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
     return capturable_supported_devices


-def _to_scalar(x):
+def _to_scalar(x: Union[float, torch.Tensor]):
     r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor
     if it is a nonzero-dimensions 1-element tensor. If it is not a tensor, it is
     kept as is.
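For readers unfamiliar with the helper touched above, here is a minimal sketch of the behavior its docstring describes (an illustrative approximation, not the actual torch.optim implementation; `_to_scalar_sketch` is a made-up name):

```python
from typing import Union

import torch


def _to_scalar_sketch(x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    """Reduce a nonzero-dim, 1-element tensor to a 0-dim (scalar) tensor;
    leave plain floats and 0-dim tensors untouched."""
    if isinstance(x, torch.Tensor) and x.dim() != 0 and x.numel() == 1:
        return x.squeeze()
    return x


assert _to_scalar_sketch(0.9) == 0.9                          # plain float passes through
assert _to_scalar_sketch(torch.tensor([[0.999]])).dim() == 0  # (1, 1) tensor -> 0-dim tensor
assert _to_scalar_sketch(torch.tensor(0.9)).dim() == 0        # already 0-dim, unchanged
```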
@@ -528,7 +528,7 @@ def optim_inputs_func_adam(device, dtype=None):
             params=None,
             kwargs={
                 "lr": torch.tensor(0.001),
-                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
                 "amsgrad": True,
                 "capturable": True,
             },
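As a quick standalone sanity check mirroring the updated test input (a sketch; only construction is exercised here, since the capturable configurations are normally run on supported devices via `pytest -s test/test_optim.py -k test_tensor_lr -vv`):

```python
import torch
from torch.optim import Adam

params = [torch.nn.Parameter(torch.randn(3))]

# Same hyperparameter shapes as the new optim input above: a tensor lr plus
# nonzero-dim, 1-element tensor betas. Construction runs the new validation
# and _to_scalar conversion in Adam.__init__.
Adam(
    params,
    lr=torch.tensor(0.001),
    betas=(torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
    amsgrad=True,
    capturable=True,
)
```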