mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	There are 31 places that I spotted which construct literal dictionaries.
This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
		
	
		
			
				
	
	
		
			540 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			540 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-defs
 | |
| r"""Implementation for the RMSprop algorithm."""
 | |
| 
 | |
| from typing import cast, Optional, Union
 | |
| 
 | |
| import torch
 | |
| from torch import Tensor
 | |
| 
 | |
| from .optimizer import (
 | |
|     _capturable_doc,
 | |
|     _default_to_fused_or_foreach,
 | |
|     _differentiable_doc,
 | |
|     _disable_dynamo_if_unsupported,
 | |
|     _foreach_doc,
 | |
|     _get_capturable_supported_devices,
 | |
|     _get_scalar_dtype,
 | |
|     _maximize_doc,
 | |
|     _params_doc,
 | |
|     _to_scalar,
 | |
|     _use_grad_for_differentiable,
 | |
|     _view_as_real,
 | |
|     Optimizer,
 | |
|     ParamsT,
 | |
| )
 | |
| 
 | |
| 
 | |
# Public names exported by this module: the optimizer class and the
# functional entry point defined below.
__all__ = ["RMSprop", "rmsprop"]
 | |
| 
 | |
| 
 | |
| class RMSprop(Optimizer):  # noqa: D101
 | |
|     def __init__(
 | |
|         self,
 | |
|         params: ParamsT,
 | |
|         lr: Union[float, Tensor] = 1e-2,
 | |
|         alpha: float = 0.99,
 | |
|         eps: float = 1e-8,
 | |
|         weight_decay: float = 0,
 | |
|         momentum: float = 0,
 | |
|         centered: bool = False,
 | |
|         capturable: bool = False,
 | |
|         foreach: Optional[bool] = None,
 | |
|         maximize: bool = False,
 | |
|         differentiable: bool = False,
 | |
|     ):  # noqa: D107
 | |
|         if isinstance(lr, Tensor) and lr.numel() != 1:
 | |
|             raise ValueError("Tensor lr must be 1-element")
 | |
|         if not 0.0 <= lr:
 | |
|             raise ValueError(f"Invalid learning rate: {lr}")
 | |
|         if not 0.0 <= eps:
 | |
|             raise ValueError(f"Invalid epsilon value: {eps}")
 | |
|         if not 0.0 <= momentum:
 | |
|             raise ValueError(f"Invalid momentum value: {momentum}")
 | |
|         if not 0.0 <= weight_decay:
 | |
|             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
 | |
|         if not 0.0 <= alpha:
 | |
|             raise ValueError(f"Invalid alpha value: {alpha}")
 | |
| 
 | |
|         defaults = {
 | |
|             "lr": lr,
 | |
|             "momentum": momentum,
 | |
|             "alpha": alpha,
 | |
|             "eps": eps,
 | |
|             "centered": centered,
 | |
|             "weight_decay": weight_decay,
 | |
|             "capturable": capturable,
 | |
|             "foreach": foreach,
 | |
|             "maximize": maximize,
 | |
|             "differentiable": differentiable,
 | |
|         }
 | |
|         super().__init__(params, defaults)
 | |
| 
 | |
|     def __setstate__(self, state):  # noqa: D105
 | |
|         super().__setstate__(state)
 | |
|         for group in self.param_groups:
 | |
|             group.setdefault("momentum", 0)
 | |
|             group.setdefault("centered", False)
 | |
|             group.setdefault("foreach", None)
 | |
|             group.setdefault("maximize", False)
 | |
|             group.setdefault("differentiable", False)
 | |
|             group.setdefault("capturable", False)
 | |
|             for p in group["params"]:
 | |
|                 p_state = self.state.get(p, [])
 | |
|                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
 | |
|                     step_val = float(p_state["step"])
 | |
|                     p_state["step"] = (
 | |
|                         torch.tensor(
 | |
|                             step_val, dtype=_get_scalar_dtype(), device=p.device
 | |
|                         )
 | |
|                         if group["capturable"]
 | |
|                         else torch.tensor(step_val, dtype=_get_scalar_dtype())
 | |
|                     )
 | |
| 
 | |
|     def _init_group(
 | |
|         self,
 | |
|         group,
 | |
|         params_with_grad,
 | |
|         grads,
 | |
|         square_avgs,
 | |
|         momentum_buffer_list,
 | |
|         grad_avgs,
 | |
|         state_steps,
 | |
|     ):
 | |
|         has_complex = False
 | |
|         for p in group["params"]:
 | |
|             if p.grad is None:
 | |
|                 continue
 | |
|             has_complex |= torch.is_complex(p)
 | |
|             params_with_grad.append(p)
 | |
| 
 | |
|             if p.grad.is_sparse:
 | |
|                 raise RuntimeError("RMSprop does not support sparse gradients")
 | |
|             grads.append(p.grad)
 | |
| 
 | |
|             state = self.state[p]
 | |
| 
 | |
|             # State initialization
 | |
|             if len(state) == 0:
 | |
|                 state["step"] = (
 | |
|                     torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
 | |
|                     if group["capturable"]
 | |
|                     else torch.zeros((), dtype=_get_scalar_dtype())
 | |
|                 )
 | |
|                 state["square_avg"] = torch.zeros_like(
 | |
|                     p, memory_format=torch.preserve_format
 | |
|                 )
 | |
|                 if group["momentum"] > 0:
 | |
|                     state["momentum_buffer"] = torch.zeros_like(
 | |
|                         p, memory_format=torch.preserve_format
 | |
|                     )
 | |
|                 if group["centered"]:
 | |
|                     state["grad_avg"] = torch.zeros_like(
 | |
|                         p, memory_format=torch.preserve_format
 | |
|                     )
 | |
|             square_avgs.append(state["square_avg"])
 | |
|             state_steps.append(state["step"])
 | |
| 
 | |
|             if group["momentum"] > 0:
 | |
|                 momentum_buffer_list.append(state["momentum_buffer"])
 | |
|             if group["centered"]:
 | |
|                 grad_avgs.append(state["grad_avg"])
 | |
| 
 | |
|         return has_complex
 | |
| 
 | |
|     @_use_grad_for_differentiable
 | |
|     def step(self, closure=None):
 | |
|         """Perform a single optimization step.
 | |
| 
 | |
|         Args:
 | |
|             closure (Callable, optional): A closure that reevaluates the model
 | |
|                 and returns the loss.
 | |
|         """
 | |
|         self._cuda_graph_capture_health_check()
 | |
| 
 | |
|         loss = None
 | |
|         if closure is not None:
 | |
|             with torch.enable_grad():
 | |
|                 loss = closure()
 | |
| 
 | |
|         for group in self.param_groups:
 | |
|             params_with_grad: list[Tensor] = []
 | |
|             grads: list[Tensor] = []
 | |
|             square_avgs: list[Tensor] = []
 | |
|             grad_avgs: list[Tensor] = []
 | |
|             momentum_buffer_list: list[Tensor] = []
 | |
|             state_steps: list[Tensor] = []
 | |
| 
 | |
|             has_complex = self._init_group(
 | |
|                 group,
 | |
|                 params_with_grad,
 | |
|                 grads,
 | |
|                 square_avgs,
 | |
|                 momentum_buffer_list,
 | |
|                 grad_avgs,
 | |
|                 state_steps,
 | |
|             )
 | |
| 
 | |
|             rmsprop(
 | |
|                 params_with_grad,
 | |
|                 grads,
 | |
|                 square_avgs,
 | |
|                 grad_avgs,
 | |
|                 momentum_buffer_list,
 | |
|                 state_steps,
 | |
|                 lr=group["lr"],
 | |
|                 alpha=group["alpha"],
 | |
|                 eps=group["eps"],
 | |
|                 weight_decay=group["weight_decay"],
 | |
|                 momentum=group["momentum"],
 | |
|                 centered=group["centered"],
 | |
|                 foreach=group["foreach"],
 | |
|                 maximize=group["maximize"],
 | |
|                 differentiable=group["differentiable"],
 | |
|                 capturable=group["capturable"],
 | |
|                 has_complex=has_complex,
 | |
|             )
 | |
| 
 | |
|         return loss
 | |
| 
 | |
| 
 | |
# The class docstring is assembled here rather than written inside the class
# body: the second half is an f-string that interpolates the argument
# descriptions imported from ``.optimizer`` (``_params_doc``,
# ``_capturable_doc``, ...), and an f-string cannot serve as a literal
# docstring.
RMSprop.__doc__ = (
    r"""Implements RMSprop algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \alpha \text{ (alpha)}, \: \gamma \text{ (lr)},
                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
            &\hspace{13mm}   \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},
                \: centered, \: \epsilon \text{ (epsilon)}                                       \\
            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}v_t           \leftarrow   \alpha v_{t-1} + (1 - \alpha) g^2_t
                \hspace{8mm}                                                                     \\
            &\hspace{5mm} \tilde{v_t} \leftarrow v_t                                             \\
            &\hspace{5mm}if \: centered                                                          \\
            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t            \\
            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} -  \big(g^{ave}_{t} \big)^2        \\
            &\hspace{5mm}if \: \mu > 0                                                           \\
            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
                g_t/ \big(\sqrt{\tilde{v_t}} +  \epsilon \big)                                   \\
            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t                \\
            &\hspace{5mm} else                                                                   \\
            &\hspace{10mm}\theta_t      \leftarrow   \theta_{t-1} -
                \gamma  g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big)  \hspace{3mm}              \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to
    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton,
    and for the centered version to `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
    The implementation here takes the square root of the gradient average before
    adding epsilon (note that TensorFlow interchanges these two operations). The effective
    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
    is the scheduled learning rate and :math:`v` is the weighted moving average
    of the squared gradient.
    """
    + rf"""
    Args:
        {_params_doc}
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        alpha (float, optional): smoothing constant (default: 0.99)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        momentum (float, optional): momentum factor (default: 0)
        centered (bool, optional): if ``True``, compute the centered RMSProp,
            the gradient is normalized by an estimation of its variance (default: False)
        {_capturable_doc}
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}

    """
)
 | |
| 
 | |
| 
 | |
| def _single_tensor_rmsprop(
 | |
|     params: list[Tensor],
 | |
|     grads: list[Tensor],
 | |
|     square_avgs: list[Tensor],
 | |
|     grad_avgs: list[Tensor],
 | |
|     momentum_buffer_list: list[Tensor],
 | |
|     state_steps: list[Tensor],
 | |
|     *,
 | |
|     lr: float,
 | |
|     alpha: float,
 | |
|     eps: float,
 | |
|     weight_decay: float,
 | |
|     momentum: float,
 | |
|     centered: bool,
 | |
|     maximize: bool,
 | |
|     differentiable: bool,
 | |
|     capturable: bool,
 | |
|     has_complex: bool,
 | |
| ):
 | |
|     if not torch.jit.is_scripting():
 | |
|         lr = _to_scalar(lr)
 | |
| 
 | |
|     for i, param in enumerate(params):
 | |
|         step = state_steps[i]
 | |
| 
 | |
|         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
 | |
|         if not torch.compiler.is_compiling() and capturable:
 | |
|             capturable_supported_devices = _get_capturable_supported_devices()
 | |
|             assert (
 | |
|                 param.device.type == step.device.type
 | |
|                 and param.device.type in capturable_supported_devices
 | |
|             ), (
 | |
|                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 | |
|             )
 | |
| 
 | |
|         grad = grads[i]
 | |
|         grad = grad if not maximize else -grad
 | |
|         square_avg = square_avgs[i]
 | |
| 
 | |
|         step += 1
 | |
| 
 | |
|         if weight_decay != 0:
 | |
|             grad = grad.add(param, alpha=weight_decay)
 | |
| 
 | |
|         is_complex_param = torch.is_complex(param)
 | |
|         if is_complex_param:
 | |
|             param = torch.view_as_real(param)
 | |
|             grad = torch.view_as_real(grad)
 | |
|             square_avg = torch.view_as_real(square_avg)
 | |
| 
 | |
|         square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
 | |
| 
 | |
|         if centered:
 | |
|             grad_avg = grad_avgs[i]
 | |
|             if is_complex_param:
 | |
|                 grad_avg = torch.view_as_real(grad_avg)
 | |
|             grad_avg.lerp_(grad, 1 - alpha)
 | |
|             avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
 | |
|         else:
 | |
|             avg = square_avg.sqrt()
 | |
| 
 | |
|         if differentiable:
 | |
|             avg = avg.add(eps)
 | |
|         else:
 | |
|             avg = avg.add_(eps)
 | |
| 
 | |
|         if momentum > 0:
 | |
|             buf = momentum_buffer_list[i]
 | |
|             if is_complex_param:
 | |
|                 buf = torch.view_as_real(buf)
 | |
|             buf.mul_(momentum).addcdiv_(grad, avg)
 | |
|             param.add_(buf, alpha=-lr)
 | |
|         else:
 | |
|             param.addcdiv_(grad, avg, value=-lr)
 | |
| 
 | |
| 
 | |
| def _multi_tensor_rmsprop(
 | |
|     params: list[Tensor],
 | |
|     grads: list[Tensor],
 | |
|     square_avgs: list[Tensor],
 | |
|     grad_avgs: list[Tensor],
 | |
|     momentum_buffer_list: list[Tensor],
 | |
|     state_steps: list[Tensor],
 | |
|     *,
 | |
|     lr: float,
 | |
|     alpha: float,
 | |
|     eps: float,
 | |
|     weight_decay: float,
 | |
|     momentum: float,
 | |
|     centered: bool,
 | |
|     maximize: bool,
 | |
|     differentiable: bool,
 | |
|     capturable: bool,
 | |
|     has_complex: bool,
 | |
| ):
 | |
|     if len(params) == 0:
 | |
|         return
 | |
| 
 | |
|     assert not differentiable, "_foreach ops don't support autograd"
 | |
| 
 | |
|     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
 | |
|     if not torch.compiler.is_compiling() and capturable:
 | |
|         capturable_supported_devices = _get_capturable_supported_devices()
 | |
|         assert all(
 | |
|             p.device.type == step.device.type
 | |
|             and p.device.type in capturable_supported_devices
 | |
|             for p, step in zip(params, state_steps)
 | |
|         ), (
 | |
|             f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
 | |
|         )
 | |
| 
 | |
|     lr = _to_scalar(lr)
 | |
| 
 | |
|     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
 | |
|         [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]  # type: ignore[list-item]
 | |
|     )
 | |
|     for (
 | |
|         (
 | |
|             grouped_params_,
 | |
|             grouped_grads_,
 | |
|             grouped_square_avgs_,
 | |
|             grouped_grad_avgs_,
 | |
|             grouped_momentum_buffer_list_,
 | |
|             grouped_state_steps_,
 | |
|         )
 | |
|     ), _ in grouped_tensors.values():
 | |
|         grouped_params = cast(list[Tensor], grouped_params_)
 | |
|         grouped_grads = cast(list[Tensor], grouped_grads_)
 | |
|         grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_)
 | |
|         grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
 | |
| 
 | |
|         if has_complex:
 | |
|             state_and_grads = [grouped_grads, grouped_square_avgs]
 | |
|             if momentum > 0:
 | |
|                 grouped_momentum_buffer_list = cast(
 | |
|                     list[Tensor], grouped_momentum_buffer_list_
 | |
|                 )
 | |
|                 state_and_grads.append(grouped_momentum_buffer_list)
 | |
|             if centered:
 | |
|                 grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
 | |
|                 state_and_grads.append(grouped_grad_avgs)
 | |
|             _view_as_real(grouped_params, *state_and_grads)
 | |
| 
 | |
|         if maximize:
 | |
|             grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]
 | |
| 
 | |
|         # Update steps
 | |
|         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
 | |
|         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
 | |
|         # wrapped it once now. The alpha is required to assure we go to the right overload.
 | |
|         if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
 | |
|             torch._foreach_add_(
 | |
|                 grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
 | |
|             )
 | |
|         else:
 | |
|             torch._foreach_add_(grouped_state_steps, 1)
 | |
| 
 | |
|         if weight_decay != 0:
 | |
|             # Reuse the intermediate memory (grouped_grads) already allocated for maximize
 | |
|             if maximize:
 | |
|                 torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
 | |
|             else:
 | |
|                 grouped_grads = torch._foreach_add(  # type: ignore[assignment]
 | |
|                     grouped_grads, grouped_params, alpha=weight_decay
 | |
|                 )
 | |
| 
 | |
|         torch._foreach_mul_(grouped_square_avgs, alpha)
 | |
|         torch._foreach_addcmul_(
 | |
|             grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
 | |
|         )
 | |
| 
 | |
|         if centered:
 | |
|             grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
 | |
|             torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
 | |
|             avg = torch._foreach_addcmul(
 | |
|                 grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
 | |
|             )
 | |
|             torch._foreach_sqrt_(avg)
 | |
|             torch._foreach_add_(avg, eps)
 | |
|         else:
 | |
|             avg = torch._foreach_sqrt(grouped_square_avgs)
 | |
|             torch._foreach_add_(avg, eps)
 | |
| 
 | |
|         if momentum > 0:
 | |
|             grouped_momentum_buffer_list = cast(
 | |
|                 list[Tensor], grouped_momentum_buffer_list_
 | |
|             )
 | |
|             torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
 | |
|             torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
 | |
|             # If LR is a tensor, the else branch will internally call item()
 | |
|             # which will cause silent incorrectness if we are capturing
 | |
|             if capturable and isinstance(lr, torch.Tensor):
 | |
|                 momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
 | |
|                 torch._foreach_add_(grouped_params, momentum_lr)
 | |
|             else:
 | |
|                 torch._foreach_add_(
 | |
|                     grouped_params, grouped_momentum_buffer_list, alpha=-lr
 | |
|                 )
 | |
|         else:
 | |
|             # If LR is a tensor, the else branch will internally call item()
 | |
|             # which will cause silent incorrectness if we are capturing
 | |
|             if capturable and isinstance(lr, torch.Tensor):
 | |
|                 torch._foreach_div_(avg, -lr)
 | |
|                 torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
 | |
|             else:
 | |
|                 torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
 | |
| 
 | |
| 
 | |
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: list[Tensor],
    grads: list[Tensor],
    square_avgs: list[Tensor],
    grad_avgs: list[Tensor],
    momentum_buffer_list: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    Validates ``state_steps``, picks between the per-tensor and the foreach
    implementation, and forwards every argument to it unchanged.

    Args:
        params: parameters to update.
        grads: gradients, index-aligned with ``params``.
        square_avgs: per-parameter state, index-aligned with ``params``.
        grad_avgs: per-parameter state read when ``centered`` is true.
        momentum_buffer_list: per-parameter state read when ``momentum > 0``.
        state_steps: per-parameter step counters; every entry must be a tensor.
        foreach: whether to use the foreach implementation. When ``None`` the
            choice is delegated to ``_default_to_fused_or_foreach``.
        maximize, differentiable, capturable, has_complex: forwarded as-is.
        lr, alpha, eps, weight_decay, momentum, centered: hyperparameters,
            forwarded as-is.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor (checked only
            outside of compilation), or if the foreach implementation is
            selected while scripting.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        # Only the foreach half of the returned pair is relevant here.
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # NOTE(review): the repeated is_scripting() test looks redundant after the
    # raise above; presumably it is kept for TorchScript's benefit - confirm
    # before simplifying.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )
 |