# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""Implementation of the Muon optimizer."""

import math
from collections.abc import MutableMapping
from typing import Optional

import torch
from torch import Tensor

from .optimizer import (
    _disable_dynamo_if_unsupported,
    _params_doc,
    _to_scalar,
    Optimizer,
    ParamsT,
)


__all__ = ["Muon"]


# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
# github permalink: https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L16
EPS = 1e-7
DEFAULT_A = 3.4445
DEFAULT_B = -4.7750
DEFAULT_C = 2.0315
DEFAULT_NS_STEPS = 5
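# Note on the defaults above: the Newton-Schulz step below applies the odd matrix
# polynomial p(X) = a*X + b*(X @ X.T) @ X + c*(X @ X.T) @ (X @ X.T) @ X with
# (a, b, c) = (DEFAULT_A, DEFAULT_B, DEFAULT_C), repeated DEFAULT_NS_STEPS times.
# The numbers come from Keller Jordan's reference implementation linked above; per the
# docstring below they are tuned for a steep slope at zero rather than for exact
# convergence of the singular values to 1.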
def _zeropower_via_newtonschulz(
    grad: Tensor, ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float
) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Implementation reference: https://github.com/KellerJordan/Muon/blob/master/muon.py
    with suggestions by @jxbz, @leloykun, and @YouJiacheng.
    """
    if ns_steps >= 100:
        raise ValueError(
            "Number of steps must be less than 100 for computational efficiency"
        )
    if len(grad.shape) != 2:
        raise ValueError("Input tensor gradient must be a 2D matrix")
    if len(ns_coefficients) != 3:
        raise ValueError("Coefficients must be a tuple of exactly 3 values")
    a, b, c = ns_coefficients
    ortho_grad = grad.bfloat16()
    if grad.size(0) > grad.size(1):
        ortho_grad = ortho_grad.T
    # Ensure spectral norm is at most 1
    ortho_grad.div_(ortho_grad.norm().clamp(min=eps))
    # Perform the NS iterations
    for _ in range(ns_steps):
        gram_matrix = ortho_grad @ ortho_grad.T
        gram_update = torch.addmm(
            gram_matrix, gram_matrix, gram_matrix, beta=b, alpha=c
        )
        ortho_grad = torch.addmm(ortho_grad, gram_update, ortho_grad, beta=a)

    if grad.size(0) > grad.size(1):
        ortho_grad = ortho_grad.T
    return ortho_grad
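# Illustrative sanity check for the helper above (a sketch, not executed at import
# time): after the iteration the singular values of the output should sit near 1, so
# the result behaves roughly like U @ V.T of the input's SVD.
#
#     g = torch.randn(256, 128)
#     o = _zeropower_via_newtonschulz(
#         g, (DEFAULT_A, DEFAULT_B, DEFAULT_C), DEFAULT_NS_STEPS, EPS
#     )
#     print(torch.linalg.svdvals(o.float()))  # values roughly in [0.5, 1.5], near 1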
def _adjust_lr(
    lr: float, adjust_lr_fn: Optional[str], param_shape: torch.Size
) -> float:
    """Default learning rate adjustment used by Muon."""
    A, B = param_shape[:2]

    if adjust_lr_fn is None or adjust_lr_fn == "original":
        # pyrefly: ignore # no-matching-overload
        adjusted_ratio = math.sqrt(max(1, A / B))
    elif adjust_lr_fn == "match_rms_adamw":
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    else:
        adjusted_ratio = 1.0
    return lr * adjusted_ratio
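# Worked example of the two adjustment modes (a sketch, assuming a 1024 x 4096 weight
# and lr = 1e-3):
#   "original":         1e-3 * sqrt(max(1, 1024 / 4096)) = 1e-3
#   "match_rms_adamw":   1e-3 * 0.2 * sqrt(max(1024, 4096)) = 1e-3 * 12.8 = 1.28e-2
# The first mode only scales the step up when A > B (more rows than columns); the
# second rescales every matrix so the update RMS roughly matches what AdamW produces.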
class Muon(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        weight_decay: float = 0.1,
        momentum: float = 0.95,
        nesterov: bool = True,
        ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C),
        eps: float = EPS,
        ns_steps: int = DEFAULT_NS_STEPS,
        adjust_lr_fn: Optional[str] = None,
    ) -> None:
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
        if not 0.0 <= momentum:
            raise ValueError(f"momentum should be >= 0 but is: {momentum}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"weight decay should be >= 0 but is: {weight_decay}")
        if adjust_lr_fn is not None and adjust_lr_fn not in [
            "original",
            "match_rms_adamw",
        ]:
            raise ValueError(
                f"Adjust learning rate function {adjust_lr_fn} is not supported"
            )

        defaults = {
            "lr": lr,
            "weight_decay": weight_decay,
            "momentum": momentum,
            "nesterov": nesterov,
            "ns_coefficients": ns_coefficients,
            "eps": eps,
            "ns_steps": ns_steps,
            "adjust_lr_fn": adjust_lr_fn,
        }
        super().__init__(params, defaults)

        for group in self.param_groups:
            for p in group["params"]:
                if p.ndim != 2:
                    raise ValueError(
                        f"Muon only supports 2D parameters whereas we found a parameter with size: {p.size()}"
                    )
    def _init_group(
        self,
        group: MutableMapping,
        params_with_grad: list[Tensor],
        grads: list[Tensor],
        muon_momentum_bufs: list[Tensor],
    ):
        for p in group["params"]:
            if p.grad is None:
                continue

            if torch.is_complex(p):
                raise RuntimeError("Muon does not support complex parameters")
            if p.grad.is_sparse:
                raise RuntimeError("Muon does not support sparse gradients")

            params_with_grad.append(p)
            grads.append(p.grad)

            state = self.state[p]

            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(
                    p.grad, memory_format=torch.preserve_format
                )
            muon_momentum_bufs.append(state["momentum_buffer"])

        return False  # has_complex
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]

            params_with_grad: list[Tensor] = []
            grads: list[Tensor] = []
            muon_momentum_bufs: list[Tensor] = []

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                muon_momentum_bufs,
            )

            muon(
                params_with_grad,
                grads,
                muon_momentum_bufs,
                lr=lr,
                weight_decay=weight_decay,
                momentum=momentum,
                nesterov=group["nesterov"],
                ns_coefficients=group["ns_coefficients"],
                eps=group["eps"],
                ns_steps=group["ns_steps"],
                adjust_lr_fn=group["adjust_lr_fn"],
                has_complex=has_complex,
            )
        return loss
Muon.__doc__ = (
    r"""Implements Muon algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)},\ \lambda \text{ (weight decay)},\
                \mu \text{ (momentum)},\ \textit{nesterov}\in\{True,False\},\\
            &\hspace{13mm}(a,b,c)\ \text{ (NS coefficients)},\
                \varepsilon \text{ (epsilon)},\ k \text{ (NS steps)},\
                \theta_0 \text{ (params)},\ f(\theta) \text{ (objective)} \\
            &\textbf{initialize} : B_0 \leftarrow 0 \text{ (momentum buffer)} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for}\ t=1\ \textbf{to}\ \ldots\ \textbf{do} \\[0.25ex]
            &\hspace{5mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1}) \\[0.25ex]
            &\hspace{5mm} B_t \leftarrow \mu B_{t-1} + g_t \\[0.25ex]
            &\hspace{5mm} \widetilde{B}_t \leftarrow
                \begin{cases}
                    g_t + \mu B_t, & \text{if nesterov}=True \\
                    B_t, & \text{if nesterov}=False
                \end{cases} \\[1.0ex]
            &\hspace{5mm} O_t \leftarrow \mathrm{NS}^{(a,b,c)}_{k}\!\big(\widetilde{B}_t;\ \varepsilon\big) \\[0.5ex]
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma\,\lambda\,\theta_{t-1}
                \quad\text{(decoupled weight decay)} \\[0.25ex]
            &\hspace{5mm} \gamma \leftarrow \mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big) \\[0.25ex]
            &\hspace{5mm} \theta_t \leftarrow \theta_t - \gamma\, O_t \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\mathbf{return}\ \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt}
       \end{aligned}

    Here, :math:`\mathrm{NS}^{(a,b,c)}_{k}(\cdot;\varepsilon)` denotes :math:`k` iterations of the
    Newton–Schulz orthogonalization operator parameterized by coefficients :math:`(a,b,c)`
    with numerical stabilization :math:`\varepsilon`.

    The purpose of :math:`\mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big)`
    is to make the orthogonalized update have a consistent :math:`RMS` across rectangular matrices.

    Keller's original implementation scales the update by :math:`\sqrt{\max\!\left(1, \frac{A}{B}\right)}`,
    where :math:`A` and :math:`B` are the dimensions of the matrix being optimized.

    Moonshot's implementation instead focuses on matching the :math:`RMS` of AdamW. The adjustment is computed as
    :math:`\gamma \leftarrow {0.2}\gamma\,\sqrt{\max\!\left({A}, {B}\right)}`.
    The method is adopted from `Muon is Scalable for LLM Training`_. Research
    results show that with this adjustment Muon can directly reuse the learning rate
    and weight decay tuned for AdamW.

    We provide two options for the learning rate adjustment: "original", which follows Keller's
    implementation, and "match_rms_adamw", which refers to Moonshot's implementation. This gives users the
    flexibility to choose between the two. If ``adjust_lr_fn`` is not specified, the default is "original".

    For further details regarding the algorithm we refer to `Muon: An optimizer for hidden layers in neural networks`_
    and `Muon is Scalable for LLM Training`_.
    """
    + rf"""
    Args:
        {_params_doc}. Note that Muon is an optimizer for 2D parameters of neural network hidden layers. Other
            parameters, such as biases and embeddings, should be optimized by a standard method such as AdamW.
        lr (float, Tensor, optional): learning rate (default: 1e-3)
        weight_decay (float, optional): weight decay coefficient, applied as decoupled weight decay (default: 0.1)
        momentum (float, optional): momentum factor (default: 0.95)
        nesterov (bool, optional): enables Nesterov momentum. Only applicable
            when momentum is non-zero (default: True)
        ns_coefficients (tuple of three floats, optional): coefficients \(a,b,c\) for the
            Newton–Schulz orthogonalization polynomial (default: ({DEFAULT_A}, {DEFAULT_B}, {DEFAULT_C}))
        eps (float, optional): term added to the denominator for numerical stability (default: {EPS})
        ns_steps (int, optional): number of Newton–Schulz iteration steps (default: {DEFAULT_NS_STEPS})
        adjust_lr_fn (str, optional): function used to adjust the learning rate. One of "original"
            and "match_rms_adamw". If not specified, "original" is used. (default: None)

    .. _Muon\: An optimizer for hidden layers in neural networks:
        https://kellerjordan.github.io/posts/muon/
    .. _Muon is Scalable for LLM Training:
        https://arxiv.org/pdf/2502.16982

    """
)
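# Minimal usage sketch (illustrative only; ``model`` and the AdamW hyperparameters are
# assumptions, not part of this module). Muon only accepts 2D parameters, so everything
# else (biases, norms, ...) needs a second optimizer; the Args above also suggest giving
# embeddings (which are 2D) to AdamW, so one simple heuristic is to split by name:
#
#     hidden_matrices = [p for n, p in model.named_parameters()
#                        if p.ndim == 2 and "embed" not in n]
#     other_params = [p for n, p in model.named_parameters()
#                     if p.ndim != 2 or "embed" in n]
#     muon_opt = Muon(hidden_matrices, lr=1e-3, weight_decay=0.1, momentum=0.95)
#     adamw_opt = torch.optim.AdamW(other_params, lr=3e-4)
#
#     loss = model(batch)  # hypothetical forward pass returning a scalar loss
#     loss.backward()
#     muon_opt.step(); adamw_opt.step()
#     muon_opt.zero_grad(); adamw_opt.zero_grad()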
def _single_tensor_muon(
    params: list[Tensor],
    grads: list[Tensor],
    muon_momentum_bufs: list[Tensor],
    *,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_coefficients: tuple[float, float, float],
    ns_steps: int,
    eps: float,
    adjust_lr_fn: Optional[str],
    has_complex: bool,
) -> None:
    lr = _to_scalar(lr)
    if has_complex:
        raise ValueError("Complex parameters are not supported")

    for i, param in enumerate(params):
        grad = grads[i]
        if grad.ndim != 2:
            raise ValueError("Param gradient must be a 2D matrix")

        # Exponential moving average of gradients: buf = momentum * buf + (1 - momentum) * grad
        buf = muon_momentum_bufs[i]
        buf.lerp_(grad, 1 - momentum)
        # Nesterov blends the raw gradient back in; otherwise use the buffer directly
        update = grad.lerp(buf, momentum) if nesterov else buf

        update = _zeropower_via_newtonschulz(update, ns_coefficients, ns_steps, eps)

        adjusted_lr = _adjust_lr(lr, adjust_lr_fn, param.shape)

        # Decoupled weight decay uses the unadjusted lr; the orthogonalized step uses the adjusted lr
        param.mul_(1 - lr * weight_decay)
        param.add_(update, alpha=-adjusted_lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_muon)
def muon(
    params: list[Tensor],
    grads: list[Tensor],
    muon_momentum_bufs: list[Tensor],
    *,
    foreach: Optional[bool] = None,
    lr: float,
    weight_decay: float,
    momentum: float,
    nesterov: bool,
    ns_coefficients: tuple[float, float, float],
    ns_steps: int,
    eps: float,
    adjust_lr_fn: Optional[str],
    has_complex: bool,
):
    r"""Functional API that performs Muon algorithm computation.

    See :class:`~torch.optim.Muon` for details.
    """
    if foreach is not None and foreach:
        raise RuntimeError("Foreach is not supported for Muon yet")

    func = _single_tensor_muon

    func(
        params,
        grads,
        muon_momentum_bufs,
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_coefficients=ns_coefficients,
        ns_steps=ns_steps,
        eps=eps,
        adjust_lr_fn=adjust_lr_fn,
        has_complex=has_complex,
    )
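# Sketch of calling the functional form directly (an illustration with assumed inputs,
# not executed at import time). The caller owns the momentum buffers and passes
# pre-collected gradients:
#
#     params = [torch.randn(64, 32), torch.randn(128, 64)]
#     grads = [torch.randn_like(p) for p in params]
#     bufs = [torch.zeros_like(p) for p in params]
#     muon(
#         params, grads, bufs,
#         lr=1e-3, weight_decay=0.1, momentum=0.95, nesterov=True,
#         ns_coefficients=(DEFAULT_A, DEFAULT_B, DEFAULT_C),
#         ns_steps=DEFAULT_NS_STEPS, eps=EPS,
#         adjust_lr_fn=None, has_complex=False,
#     )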