diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index f6c70cc349fb..fab8e18c4310 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -84,6 +84,7 @@ class Adam(Optimizer):
                 )
             if betas[1].numel() != 1:
                 raise ValueError("Tensor betas[1] must be 1-element")
+        betas = tuple(map(_to_scalar, betas))
 
         defaults = {
             "lr": lr,
@@ -315,8 +316,9 @@ Adam.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
@@ -375,7 +377,8 @@ def _single_tensor_adam(
         assert isinstance(beta2, float)
     else:
         lr = _to_scalar(lr)
-        # TODO: Support nonzero-dim Tensor betas, see #147921
+        beta1 = _to_scalar(beta1)
+        beta2 = _to_scalar(beta2)
 
     # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
     # treating it as a scalar.
@@ -610,7 +613,8 @@ def _multi_tensor_adam(
     assert not differentiable, "_foreach ops don't support autograd"
 
     lr = _to_scalar(lr)
-    # TODO: Support nonzero-dim Tensor betas, see #147921
+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)
 
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
         [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]  # type: ignore[list-item]
@@ -801,8 +805,8 @@ def _fused_adam(
     *,
     amsgrad: bool,
     has_complex: bool,  # Needed for consistency.
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
@@ -816,6 +820,9 @@ def _fused_adam(
     if differentiable:
         raise RuntimeError("Adam with fused=True does not support differentiable=True")
 
+    beta1 = _to_scalar(beta1)
+    beta2 = _to_scalar(beta2)
+
     grad_scale_dict: DeviceDict = (
         {grad_scale.device: grad_scale} if grad_scale is not None else {}
     )
@@ -905,8 +912,8 @@ def adam(
     decoupled_weight_decay: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index b61a3f61b668..0558cbddd883 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -102,8 +102,9 @@ AdamW.__doc__ = (
         lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
             is not yet supported for all our implementations. Please use a float
             LR if you are not also specifying fused=True or capturable=True.
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
+        betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
+            coefficients used for computing running averages of gradient and
+            its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
         eps (float, optional): term added to the denominator to improve
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay coefficient (default: 1e-2)
@@ -145,8 +146,8 @@ def adamw(
     has_complex: bool = False,
     *,
     amsgrad: bool,
-    beta1: float,
-    beta2: float,
+    beta1: Union[float, Tensor],
+    beta2: Union[float, Tensor],
     lr: Union[float, Tensor],
     weight_decay: float,
     eps: float,
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 0dd36d93d012..6f12315e78c0 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -230,7 +230,7 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> list[str]:
     return capturable_supported_devices
 
 
-def _to_scalar(x):
+def _to_scalar(x: Union[float, torch.Tensor]):
     r"""This function converts a hyperparameter to a 0-dimension (scalar) tensor
     if it is a nonzero-dimensions 1-element tensor.
     If it is not a tensor, it is kept as is.
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 817b11b3a7a3..5036fb54cdc6 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -528,7 +528,7 @@ def optim_inputs_func_adam(device, dtype=None):
             params=None,
             kwargs={
                 "lr": torch.tensor(0.001),
-                "betas": (torch.tensor(0.9), torch.tensor(0.99)),
+                "betas": (torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
                 "amsgrad": True,
                 "capturable": True,
             },
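
A minimal usage sketch of what this change enables, mirroring the updated test
input in common_optimizers.py above. The parameter shapes and the CUDA device
are illustrative, and capturable=True assumes a capturable-supported device:

    import torch

    # betas may now be nonzero-dim 1-element tensors; _to_scalar squeezes them
    # to 0-dim scalars in the Adam constructor and the adam implementations.
    params = [torch.randn(2, 3, device="cuda", requires_grad=True)]
    opt = torch.optim.Adam(
        params,
        lr=torch.tensor(0.001),
        betas=(torch.tensor([[[0.9]]]), torch.tensor([[0.99]])),
        amsgrad=True,
        capturable=True,
    )
    params[0].sum().backward()
    opt.step()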