Make Adam, AdamW work with nonzero-dim Tensor betas (#149939)

Fixes #147921

## Changes

- Convert tensor `betas` to 0-dim scalar tensors with `_to_scalar` (see the usage sketch below)
- Update the type annotations of the `betas` / `beta1` / `beta2` parameters
- Update the `betas` parameter type in the docstrings
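
A minimal usage sketch of the new behavior, mirroring the updated test input at the bottom of the diff. The parameter list below is a stand-in for illustration; the `capturable=True` configuration is the one exercised by `test_tensor_lr`:

```python
import torch

params = [torch.nn.Parameter(torch.randn(2, 3))]

# Nonzero-dim but 1-element tensor betas are now accepted; internally they
# are converted to 0-dim scalar tensors via _to_scalar.
opt = torch.optim.Adam(
    params,
    lr=torch.tensor(0.001),
    betas=(torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
    amsgrad=True,
    capturable=True,
)
```

The `step()` path for this configuration is covered by the test run below.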

## Test Result

```bash
pytest -s test/test_optim.py -k test_tensor_lr -vv
```

![image](https://github.com/user-attachments/assets/312ee045-1e8b-4789-aa6e-ba63e6df7e81)

![image](https://github.com/user-attachments/assets/7e6ec274-645b-46b9-b1a6-2b340a685203)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149939
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Author: zeshengzong
Date: 2025-10-06 22:03:22 +00:00
Committed by: PyTorch MergeBot
Parent: 48b54b45d6
Commit: fdc8ccc5bc
4 changed files with 22 additions and 14 deletions


@@ -84,6 +84,7 @@ class Adam(Optimizer):
                 )
             if betas[1].numel() != 1:
                 raise ValueError("Tensor betas[1] must be 1-element")
+        betas = tuple(map(_to_scalar, betas))
         defaults = {
             "lr": lr,
@@ -315,8 +316,9 @@ Adam.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
@@ -375,7 +377,8 @@ def _single_tensor_adam(
         assert isinstance(beta2, float)
     else:
         lr = _to_scalar(lr)
-        # TODO: Support nonzero-dim Tensor betas, see #147921
+        beta1 = _to_scalar(beta1)
+        beta2 = _to_scalar(beta2)
     # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
     # treating it as a scalar.
@@ -610,7 +613,8 @@ def _multi_tensor_adam(
     assert not differentiable, "_foreach ops don't support autograd"
     lr = _to_scalar(lr)
-    # TODO: Support nonzero-dim Tensor betas, see #147921
+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
         [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
@@ -801,8 +805,8 @@ def _fused_adam(
     *,
     amsgrad: bool,
     has_complex: bool,  # Needed for consistency.
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
@@ -816,6 +820,9 @@
     if differentiable:
         raise RuntimeError("Adam with fused=True does not support differentiable=True")
+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)
     grad_scale_dict: DeviceDict = (
         {grad_scale.device: grad_scale} if grad_scale is not None else {}
     )
@@ -905,8 +912,8 @@ def adam(
     decoupled_weight_decay: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,


@@ -102,8 +102,9 @@ AdamW.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay coefficient (default: 1e-2)
@@ -145,8 +146,8 @@ def adamw(
     has_complex: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,


@@ -230,7 +230,7 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
     return capturable_supported_devices


-def _to_scalar(x):
+def _to_scalar(x: Union[float, torch.Tensor]):
     r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor
     if it is a nonzero-dimensions 1-element tensor. If it is not a tensor, it is
     kept as is.
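
For reference, a rough sketch of the contract described in the `_to_scalar` docstring above (an illustration only, not the actual helper; the name `to_scalar_sketch` is made up):

```python
import torch

def to_scalar_sketch(x):
    # Assumed behavior per the docstring above: a 1-element tensor of any
    # dimensionality becomes a 0-dim scalar tensor; non-tensor values are
    # returned unchanged.
    if isinstance(x, torch.Tensor) and x.numel() == 1:
        return x.squeeze()
    return x

print(to_scalar_sketch(torch.tensor([[[0.9]]])).dim())  # 0
print(to_scalar_sketch(0.999))                          # 0.999
```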


@@ -528,7 +528,7 @@ def optim_inputs_func_adam(device, dtype=None):
             params=None,
             kwargs={
                 "lr": torch.tensor(0.001),
-                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
                 "amsgrad": True,
                 "capturable": True,
             },