[muon] Introduce Muon optimizer to PyTorch (#160213)

A single-device version of Muon. The algorithm follows Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/) and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.

This implementation keeps the API minimal and consistent with the conventions of the other optimizers. The PyTorch team prefers to handle parameter filtering at a higher level: the Muon optimizer performs only the msign computation for orthogonalization on the parameters it receives, and users are responsible for grouping parameters for the different optimizers as needed. Example usage is shown below, followed by a short parameter-filtering sketch; a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.

**Usage**

```python
model = MyModelForCausalLM()
# filter out your params manually
muon_params = [...]
adamw_params = [...]
muon = Muon(
    params=muon_params,
    lr=lr,
    weight_decay=wd,
)
adamw = AdamW(
    params=adamw_params,
    lr=lr,
    weight_decay=wd,
)

# in the training loop
loss = model(input)
loss.backward()
muon.step()
adamw.step()
muon.zero_grad()
adamw.zero_grad()
```
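As a concrete illustration of the manual filtering above, here is a minimal sketch. It assumes hidden-layer weights are the 2D parameters you want Muon to handle, and that embedding and LM-head parameters (also 2D in many models) are identified by name and left to AdamW; the name checks are purely illustrative.

```python
muon_params, adamw_params = [], []
for name, p in model.named_parameters():
    # Route 2D hidden-layer weights to Muon; keep biases, norms,
    # embeddings, and the LM head (names are model-specific) on AdamW.
    if p.ndim == 2 and "embed" not in name and "lm_head" not in name:
        muon_params.append(p)
    else:
        adamw_params.append(p)
```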

~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~

```python
~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~
```

As discussed with the team and in the comments, we prefer a simpler and cleaner interface, so we removed the callback interface and canonicalized the original Newton-Schulz (NS) algorithm for Muon. The only NS knobs exposed to users are `ns_steps`, `ns_coefficients`, and `eps`, configurable through kwargs.

By default, we use 5-step Newton-Schulz with the coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). For the learning rate adjustment, we support both Keller's original scaling and the strategy proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts the learning rate from AdamW (`adjust_lr_fn="match_rms_adamw"`).
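A hedged sketch of how these knobs can be overridden through kwargs (argument names follow `torch/optim/_muon.py` introduced in this PR; the values shown are illustrative, not recommendations):

```python
muon = Muon(
    muon_params,
    lr=lr,
    weight_decay=wd,
    ns_steps=6,                                 # default: 5 Newton-Schulz iterations
    ns_coefficients=(3.4445, -4.7750, 2.0315),  # Keller's quintic coefficients (also the defaults)
    eps=1e-7,                                   # stabilizer for the spectral-norm normalization
    adjust_lr_fn="match_rms_adamw",             # Moonshot-style grafting; default is "original"
)
```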

**Testing**

~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~

As discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch. We also updated the interface so we no longer need to test FQN-based filtering. Muon is covered by the existing unit tests in `test/test_optim.py`.

2. End-to-end test: we added a training script that pre-trains a Qwen-like model on the `openwebtext-100k` dataset. We trained for one epoch and compared the resulting loss curve against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />

**Numerics**
We evaluated our implementation against the existing implementation to confirm numerical consistency.

As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `AdamW`, making Muon a drop-in replacement.
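Concretely, the two adjustments available in this PR mirror `_adjust_lr` in `torch/optim/_muon.py`; the helper below restates them for reference:

```python
import math

def adjust_lr(lr: float, adjust_lr_fn, param_shape) -> float:
    # Restates _adjust_lr from torch/optim/_muon.py in this PR.
    A, B = param_shape[:2]
    if adjust_lr_fn is None or adjust_lr_fn == "original":
        ratio = math.sqrt(max(1, A / B))    # Keller's scaling
    elif adjust_lr_fn == "match_rms_adamw":
        ratio = 0.2 * math.sqrt(max(A, B))  # Moonshot: match AdamW update RMS
    else:
        ratio = 1.0
    return lr * ratio
```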

As expected, the numerics difference comes mainly from `adjust_lr`, with a maximum of ~5% relative difference in the example unit-test setup below.

```python
# Note: KellySingleDeviceMuon refers to the reference single-device Muon from
# Keller Jordan's repo (https://github.com/KellerJordan/Muon); it is not part
# of PyTorch and must be vendored for this comparison.
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()

lr = 1e-3
wd = 0.1
momentum = 0.95

opt_ref_muon = KellySingleDeviceMuon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
)

opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```

As explained above, including this `adjust_lr` is preferable. This is validated by an e2e training run on a Qwen-2-like 0.5B model, where the curves show that training with `adjust_lr` converges more effectively than without it.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />

**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20 s slower than AdamW per epoch; assuming no other changes, that makes Muon about *2.5%* slower.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />

**Note**
We restrict the implementation to accept only 2D parameters.

An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.

Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
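For illustration, passing a non-2D parameter fails at construction time; a small sketch (the error text follows the implementation in this PR):

```python
import torch
from torch.optim import Muon

bias = torch.nn.Parameter(torch.randn(10))  # 1D parameter
try:
    Muon([bias], lr=1e-3)
except ValueError as err:
    # "Muon only supports 2D parameters whereas we found a parameter with size: torch.Size([10])"
    print(err)
```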

**Next Steps**

1. Add `MuP` support.
2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x - 1.48x speedup on small to large matrices (n = 256 -> 16384).
3. Open-source an unsharded Muon co-designed with FSDP2.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
commit 74280d0913 (parent 1de4540449)
Author: Chuanhao Zhuge
Date: 2025-08-24 08:03:04 +00:00
Committed by: PyTorch MergeBot
6 changed files with 591 additions and 50 deletions


@ -165,6 +165,7 @@ for input, target in dataset:
Adamax
ASGD
LBFGS
Muon
NAdam
RAdam
RMSprop
@ -210,6 +211,7 @@ Below is a table showing the available and default implementations of each algor
:class:`Adamax`;foreach;yes;no
:class:`ASGD`;foreach;yes;no
:class:`LBFGS`;for-loop;no;no
:class:`Muon`;for-loop;no;no
:class:`NAdam`;foreach;yes;no
:class:`RAdam`;foreach;yes;no
:class:`RMSprop`;foreach;yes;no
@ -233,6 +235,7 @@ Below table is showing the stability status for fused implementations:
:class:`Adamax`;unsupported;unsupported;unsupported
:class:`ASGD`;unsupported;unsupported;unsupported
:class:`LBFGS`;unsupported;unsupported;unsupported
:class:`Muon`;unsupported;unsupported;unsupported
:class:`NAdam`;unsupported;unsupported;unsupported
:class:`RAdam`;unsupported;unsupported;unsupported
:class:`RMSprop`;unsupported;unsupported;unsupported


@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import random
import sys
import types
import unittest
@ -583,6 +584,9 @@ class CompiledOptimizerParityTests(TestCase):
@optims(optim_db, dtypes=[torch.float32])
@parametrize("use_closure", [True, False])
def test_correctness(self, device, dtype, optim_info, use_closure):
torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
random.seed(0)
optim_cls = optim_info.optim_cls
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
@ -604,7 +608,10 @@ class CompiledOptimizerParityTests(TestCase):
torch._inductor.metrics.reset()
input = torch.ones([10, 10], device=device)
model_eager = torch.nn.Sequential(
*[torch.nn.Linear(10, 10, device=device) for _ in range(2)]
*[
torch.nn.Linear(10, 10, device=device, bias=False)
for _ in range(2)
]
)
model_eager(input).sum().backward()
model_compiled = deepcopy(model_eager)


@ -187,7 +187,8 @@ class TestOptimRenewed(TestCase):
)
input = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
optimizer = optim_cls(params, **optim_input.kwargs)
schedulers = [
s(optimizer)
for s in (schedulers_constructor if schedulers_constructor else [])
@ -195,7 +196,12 @@ class TestOptimRenewed(TestCase):
def closure():
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
wo = (
weight.mv(input)
if optim_cls.__name__ == "Muon"
else weight.mv(input) + bias
)
loss = wo.pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
@ -246,7 +252,8 @@ class TestOptimRenewed(TestCase):
bias = Parameter(torch.randn((10), device="cuda:1", dtype=dtype))
inpt = torch.randn(5, device="cuda:0", dtype=dtype)
optimizer = optim_cls([weight, bias], **optim_input.kwargs)
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
optimizer = optim_cls(params, **optim_input.kwargs)
schedulers = [
s(optimizer)
for s in (schedulers_constructor if schedulers_constructor else [])
@ -254,7 +261,12 @@ class TestOptimRenewed(TestCase):
def closure():
optimizer.zero_grad()
loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
wo = (
weight.mv(inpt).cuda(1)
if optim_cls.__name__ == "Muon"
else weight.mv(inpt).cuda(1) + bias
)
loss = wo.pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
@ -285,23 +297,25 @@ class TestOptimRenewed(TestCase):
for schedulers_c in optim_info.scheduler_inputs:
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
weight2 = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
inpt = torch.randn(5, device=device, dtype=dtype)
# avoid endless recompiles by wrapping LR in a tensor if we're compiling
lr = torch.tensor(0.01) if torch.compiler.is_compiling() else 0.01
optimizer = optim_cls([{"params": [weight]}, {"params": [bias], "lr": lr}])
optimizer = optim_cls(
[{"params": [weight]}, {"params": [weight2], "lr": lr}]
)
schedulers = [scheduler_c(optimizer) for scheduler_c in schedulers_c]
def closure():
optimizer.zero_grad()
loss = (weight.mv(inpt) + bias).pow(2).sum()
loss = (weight.mv(inpt) + weight2.mv(inpt)).pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
# NOT represent the expected use case for optims like SparseAdam!
weight.grad = weight.grad.to_sparse()
bias.grad = bias.grad.to_sparse()
weight2.grad = weight2.grad.to_sparse()
return loss
initial_value = closure().item()
@ -339,21 +353,26 @@ class TestOptimRenewed(TestCase):
if "lr" in kwargs:
del kwargs["lr"]
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
optimizer_r = optim_cls([weight, bias], **kwargs)
optimizer_r = optim_cls(params, **kwargs)
try:
kwargs["lr"] = (
torch.tensor(kwargs["lr"]).reshape([1] * num_dim).to(lr_device)
)
optimizer = optim_cls([weight_c, bias_c], **kwargs)
params_c = [weight_c, bias_c]
if optim_cls.__name__ == "Muon":
params_c = [weight_c]
optimizer = optim_cls(params_c, **kwargs)
except ValueError as e:
self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
continue
def closure(optim, w, b, i):
optim.zero_grad()
loss = (w.mv(i) + b).pow(2).sum()
wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b
loss = wo.pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
@ -377,7 +396,8 @@ class TestOptimRenewed(TestCase):
optimizer.step()
self.assertEqual(weight, weight_c)
self.assertEqual(bias, bias_c)
if optim_cls.__name__ != "Muon":
self.assertEqual(bias, bias_c)
@parametrize("with_lrsched", [True, False])
@optims(
@ -1217,31 +1237,31 @@ class TestOptimRenewed(TestCase):
)
for optim_input in all_optim_inputs:
weight_kwargs = optim_input.kwargs
bias_kwargs = deepcopy(optim_input.kwargs)
bias_kwargs["weight_decay"] = 0.0
weight2_kwargs = deepcopy(optim_input.kwargs)
weight2_kwargs["weight_decay"] = 0.0
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
weight2 = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
input = torch.randn(5, device=device, dtype=dtype)
optimizer = optim_cls(
[
dict(params=[weight], **weight_kwargs),
dict(params=[bias], **bias_kwargs),
dict(params=[weight2], **weight2_kwargs),
]
)
loss = (weight.mv(input) + bias).pow(2).sum()
loss = (weight.mv(input) + weight2.mv(input)).pow(2).sum()
initial_value = loss.item()
for _ in range(20):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss = (weight.mv(input) + weight2.mv(input)).pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
# For this test, we naively convert the Tensor layout, which we know does
# NOT represent the expected use case for optims like SparseAdam!
weight.grad = weight.grad.to_sparse()
bias.grad = bias.grad.to_sparse()
weight2.grad = weight2.grad.to_sparse()
optimizer.step()
# Test that the direction of loss moved appropriately
@ -1268,22 +1288,33 @@ class TestOptimRenewed(TestCase):
weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
bias = Parameter(torch.randn((10), device=device, dtype=dtype))
irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
irrelevant = Parameter(torch.randn((2, 2), device=device, dtype=dtype))
irrelevant_clone = irrelevant.clone()
input = torch.randn(5, device=device, dtype=dtype)
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
optimizer = optim_cls(
[
dict(params=[weight, bias], **optim_input.kwargs),
dict(params=params, **optim_input.kwargs),
dict(params=[irrelevant]),
],
**outer_kwargs,
)
loss = (weight.mv(input) + bias).pow(2).sum()
wo = (
weight.mv(input)
if optim_cls.__name__ == "Muon"
else weight.mv(input) + bias
)
loss = wo.pow(2).sum()
initial_value = loss.item()
for _ in range(20):
optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
wo = (
weight.mv(input)
if optim_cls.__name__ == "Muon"
else weight.mv(input) + bias
)
loss = wo.pow(2).sum()
loss.backward()
irrelevant.grad = torch.rand_like(irrelevant)
if optim_info.only_supports_sparse_grads:
@ -1341,8 +1372,8 @@ class TestOptimRenewed(TestCase):
if kwargs.get("weight_decay", 0) != 0:
continue
# AdamW params will be updated regardless of grads due to lr, so make lr smaller
if optim_cls.__name__ == "AdamW":
# AdamW/Muon params will be updated regardless of grads due to lr, so make lr smaller
if optim_cls.__name__ == "AdamW" or optim_cls.__name__ == "Muon":
kwargs["lr"] = (
torch.tensor(1e-5)
if isinstance(kwargs.get("lr", 1e-5), torch.Tensor)
@ -1439,6 +1470,8 @@ class TestOptimRenewed(TestCase):
bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
params = [weight, bias]
if optim_cls.__name__ == "Muon":
params = [weight]
def make_named_param(param, is_named):
if not is_named:
@ -1453,7 +1486,8 @@ class TestOptimRenewed(TestCase):
def fwd_bwd(optim, w, b, i):
optim.zero_grad()
loss = (w.mv(i) + b).pow(2).sum()
wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b
loss = wo.pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
if w.grad is not None:
@ -1479,7 +1513,10 @@ class TestOptimRenewed(TestCase):
with torch.no_grad():
weight_c = Parameter(weight.clone())
bias_c = Parameter(bias.clone())
params_c = make_named_param([weight_c, bias_c], is_named=is_named_optim1)
params_c_list = (
[weight_c, bias_c] if optim_cls.__name__ != "Muon" else [weight_c]
)
params_c = make_named_param(params_c_list, is_named=is_named_optim1)
optimizer_c = optim_cls(params_c, **optim_input.kwargs)
closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)
@ -1498,7 +1535,8 @@ class TestOptimRenewed(TestCase):
optimizer_c.step()
self.assertEqual(weight, weight_c)
self.assertEqual(bias, bias_c)
if optim_cls.__name__ != "Muon":
self.assertEqual(bias, bias_c)
# Make sure state dict is deterministic with equal (not identical) parameters
# Param names are optional and not needed to be the consistent.
@ -1522,14 +1560,24 @@ class TestOptimRenewed(TestCase):
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
)
def _get_model_and_input_tensor(device, dtype, optim_cls):
if optim_cls.__name__ == "Muon":
# Muon only accepts 2D parameter.
model = torch.nn.Linear(10, 4, bias=False)
input = torch.rand(10, device=device, dtype=dtype)
else:
model = torch.nn.Sequential(
torch.nn.Conv2d(4, 2, 1, stride=2),
torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
)
input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
model.to(dtype=dtype, device=device)
return model, input
for optim_input in all_optim_inputs:
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Conv2d(4, 2, 1, stride=2),
torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
)
model.to(dtype=dtype, device=device)
input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
model, input = _get_model_and_input_tensor(device, dtype, optim_cls)
optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
def fwd_bwd(optim, mod, i):
@ -1577,14 +1625,24 @@ class TestOptimRenewed(TestCase):
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
)
def _get_model_and_input_tensor(device, dtype, optim_cls):
if optim_cls.__name__ == "Muon":
# Muon only accepts 2D parameter.
model = torch.nn.Linear(10, 4, bias=False)
input = torch.rand(10, device=device, dtype=dtype)
else:
model = torch.nn.Sequential(
torch.nn.Conv2d(4, 2, 1, stride=2),
torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
)
input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
model.to(dtype=dtype, device=device)
return model, input
for optim_input in all_optim_inputs:
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Conv2d(4, 2, 1, stride=2),
torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
)
model.to(dtype=dtype, device=device)
input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
model, input = _get_model_and_input_tensor(device, dtype, optim_cls)
def fwd_bwd(optim, mod, i):
optim.zero_grad()
@ -1621,11 +1679,12 @@ class TestOptimRenewed(TestCase):
fwd_bwd(optimizer2, model, input)
optimizer2.step()
ref_names = [p[0] for p in model.named_parameters()]
# Make sure that param_names are preserved when provided to at least one of the optimizers
if is_named_optim0 or is_named_optim1:
self.assertEqual(
optimizer2.state_dict()["param_groups"][0]["param_names"],
["0.weight", "0.bias", "1.weight", "1.bias"],
ref_names,
)
@parametrize("is_named_optim", [True, False])
@ -1644,7 +1703,7 @@ class TestOptimRenewed(TestCase):
)
bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
params = [weight, bias]
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
def make_named_param(param, is_named):
if not is_named:
@ -1653,7 +1712,8 @@ class TestOptimRenewed(TestCase):
def fwd_bwd(optim, w, b, i):
optim.zero_grad()
loss = (w.mv(i) + b).pow(2).sum()
wo = w.mv(i) if optim_cls.__name__ == "Muon" else w.mv(i) + b
loss = wo.pow(2).sum()
loss.backward()
if optim_info.only_supports_sparse_grads:
weight.grad = weight.grad.to_sparse()
@ -1937,7 +1997,7 @@ class TestOptimRenewed(TestCase):
nonlocal data
data += 2
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
params = [torch.tensor([[1, 1]], device=device, dtype=dtype)]
def dummy_closure():
return 1
@ -1969,7 +2029,8 @@ class TestOptimRenewed(TestCase):
nonlocal data
data += 2
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
# Create a random 2D tensor for compatibility with Muon.
params = [torch.tensor([[1, 1]], device=device, dtype=dtype)]
def dummy_closure():
return 1
@ -2013,7 +2074,7 @@ class TestOptimRenewed(TestCase):
nonlocal data
data.append(2)
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
params = [torch.tensor([[1, 1]], device=device, dtype=dtype)]
def dummy_closure():
return 1
@ -2219,7 +2280,8 @@ class TestOptimRenewed(TestCase):
def test_non_empty_state(self, device, dtype, optim_info):
# There are internal tests that check that the state is not empty
optim_cls = optim_info.optim_cls
model = torch.nn.Linear(5, 5)
# Muon only accepts 2D parameter.
model = torch.nn.Linear(5, 5, bias=False)
model.to(dtype=dtype, device=device)
inpt = torch.rand(2, 5, dtype=dtype, device=device)


@ -8,6 +8,7 @@ future.
from torch.optim import lr_scheduler as lr_scheduler, swa_utils as swa_utils
from torch.optim._adafactor import Adafactor as Adafactor
from torch.optim._muon import Muon as Muon
from torch.optim.adadelta import Adadelta as Adadelta
from torch.optim.adagrad import Adagrad as Adagrad
from torch.optim.adam import Adam as Adam
@ -25,6 +26,7 @@ from torch.optim.sparse_adam import SparseAdam as SparseAdam
Adafactor.__module__ = "torch.optim"
Muon.__module__ = "torch.optim"
del adadelta # type: ignore[name-defined] # noqa: F821
@ -52,6 +54,7 @@ __all__ = [
"ASGD",
"LBFGS",
"lr_scheduler",
"Muon",
"NAdam",
"Optimizer",
"RAdam",

torch/optim/_muon.py (new file, 360 lines)

@ -0,0 +1,360 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""Implementation of the Muon optimizer."""
import math
from collections.abc import MutableMapping
from typing import Optional
import torch
from torch import Tensor
from .optimizer import (
_disable_dynamo_if_unsupported,
_params_doc,
_to_scalar,
Optimizer,
ParamsT,
)
__all__ = ["Muon"]
# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
# github permlink: https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L16
EPS = 1e-7
DEFAULT_A = 3.4445
DEFAULT_B = -4.7750
DEFAULT_C = 2.0315
DEFAULT_NS_STEPS = 5
def _zeropower_via_newtonschulz(
grad: Tensor, ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float
) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
Implementation reference: https://github.com/KellerJordan/Muon/blob/master/muon.py
with suggestions by @jxbz, @leloykun, and @YouJiacheng.
"""
if ns_steps >= 100:
raise ValueError(
"Number of steps must be less than 100 for computational efficiency"
)
if len(grad.shape) != 2:
raise ValueError("Input tensor gradient must be a 2D matrix")
if len(ns_coefficients) != 3:
raise ValueError("Coefficients must be a tuple of exactly 3 values")
a, b, c = ns_coefficients
ortho_grad = grad.bfloat16()
if grad.size(0) > grad.size(1):
ortho_grad = ortho_grad.T
# Ensure spectral norm is at most 1
ortho_grad.div_(ortho_grad.norm().clamp(min=eps))
# Perform the NS iterations
for _ in range(ns_steps):
gram_matrix = ortho_grad @ ortho_grad.T
gram_update = b * gram_matrix + c * gram_matrix @ gram_matrix
ortho_grad = a * ortho_grad + gram_update @ ortho_grad
if grad.size(0) > grad.size(1):
ortho_grad = ortho_grad.T
return ortho_grad
def _adjust_lr(
lr: float, adjust_lr_fn: Optional[str], param_shape: torch.Size
) -> float:
"""Default learning rate adjustment used by Muon."""
A, B = param_shape[:2]
if adjust_lr_fn is None or adjust_lr_fn == "original":
adjusted_ratio = math.sqrt(max(1, A / B))
elif adjust_lr_fn == "match_rms_adamw":
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
else:
adjusted_ratio = 1.0
return lr * adjusted_ratio
class Muon(Optimizer):
def __init__(
self,
params: ParamsT,
lr: float = 1e-3,
weight_decay: float = 0.1,
momentum: float = 0.95,
nesterov: bool = True,
ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C),
eps: float = EPS,
ns_steps: int = DEFAULT_NS_STEPS,
adjust_lr_fn: Optional[str] = None,
) -> None:
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 <= lr:
raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
if not 0.0 <= momentum:
raise ValueError(f"momentum should be >= 0 but is: {momentum}")
if not 0.0 <= weight_decay:
raise ValueError(f"weight decay should be >= 0 but is: {weight_decay}")
if adjust_lr_fn is not None and adjust_lr_fn not in [
"original",
"match_rms_adamw",
]:
raise ValueError(
f"Adjust learning rate function {adjust_lr_fn} is not supported"
)
defaults = {
"lr": lr,
"weight_decay": weight_decay,
"momentum": momentum,
"nesterov": nesterov,
"ns_coefficients": ns_coefficients,
"eps": eps,
"ns_steps": ns_steps,
"adjust_lr_fn": adjust_lr_fn,
}
super().__init__(params, defaults)
for group in self.param_groups:
for p in group["params"]:
if p.ndim != 2:
raise ValueError(
f"Muon only supports 2D parameters whereas we found a parameter with size: {p.size()}"
)
def _init_group(
self,
group: MutableMapping,
params_with_grad: list[Tensor],
grads: list[Tensor],
muon_momentum_bufs: list[Tensor],
):
for p in group["params"]:
if p.grad is None:
continue
if torch.is_complex(p):
raise RuntimeError("Muon does not support complex parameters")
if p.grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
params_with_grad.append(p)
grads.append(p.grad)
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(
p.grad, memory_format=torch.preserve_format
)
muon_momentum_bufs.append(state["momentum_buffer"])
return False # has_complex
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group["lr"]
weight_decay = group["weight_decay"]
momentum = group["momentum"]
params_with_grad: list[Tensor] = []
grads: list[Tensor] = []
muon_momentum_bufs: list[Tensor] = []
has_complex = self._init_group(
group,
params_with_grad,
grads,
muon_momentum_bufs,
)
muon(
params_with_grad,
grads,
muon_momentum_bufs,
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
nesterov=group["nesterov"],
ns_coefficients=group["ns_coefficients"],
eps=group["eps"],
ns_steps=group["ns_steps"],
adjust_lr_fn=group["adjust_lr_fn"],
has_complex=has_complex,
)
return loss
Muon.__doc__ = (
r"""Implements Muon algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)},\ \lambda \text{ (weight decay)},\
\mu \text{ (momentum)},\ \textit{nesterov}\in\{True,False\},\\
&\hspace{13mm}(a,b,c)\ \text{ (NS coefficients)},\
\varepsilon \text{ (epsilon)},\ k \text{ (NS steps)},\
\theta_0 \text{ (params)},\ f(\theta) \text{ (objective)} \\
&\textbf{initialize} : B_0 \leftarrow 0 \text{ (momentum buffer)} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for}\ t=1\ \textbf{to}\ \ldots\ \textbf{do} \\[0.25ex]
&\hspace{5mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1}) \\[0.25ex]
&\hspace{5mm} B_t \leftarrow \mu B_{t-1} + g_t \\[0.25ex]
&\hspace{5mm} \widetilde{B}_t \leftarrow
\begin{cases}
g_t + \mu B_t, & \text{if nesterov}=True \\
B_t, & \text{if nesterov}=False
\end{cases} \\[1.0ex]
&\hspace{5mm} O_t \leftarrow \mathrm{NS}^{(a,b,c)}_{k}\!\big(\widetilde{B}_t;\ \varepsilon\big) \\[0.5ex]
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma\,\lambda\,\theta_{t-1}
\quad\text{(decoupled weight decay)} \\[0.25ex]
&\hspace{5mm} \gamma \leftarrow \mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big) \\[0.25ex]
&\hspace{5mm} \theta_t \leftarrow \theta_t - \gamma\, O_t \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\mathbf{return}\ \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt}s
\end{aligned}
Here, :math:`\mathrm{NS}^{(a,b,c)}_{k}(\cdot;\varepsilon)` denotes :math:`k` iterations of the
NewtonSchulz orthogonalization operator parameterized by coefficients :math:`(a,b,c)`
with numerical stabilization :math:`\varepsilon`.
The purpose for :math:`\mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big)`
is to make the orthogonalized update have a consistent :math:`RMS` across rectangular matrices.
Keller's original implementation scales the update by :math:`\sqrt{\max\!\left(1, \frac{A}{B}\right)}`,
where :math:`A` and :math:`B` are dimension of the matrix being optimized.
Moonshot's implementation also focuses on matching :math:`RMS` of AdamW. The adjustment is computed as:
:math:`\gamma \leftarrow {0.2}\gamma\,\sqrt{\max\!\left({A}, {B}\right)}`
The method is adopted from `Muon is Scalable for LLM Training`_. Research
results show that with this adjustment Muon can directly reuse the learning rate
and weight decay tuned for AdamW.
We provide two options for the learning rate adjustment: "original", which follows Keller's
implementation, and "match_rms_adamw", which refers to Moonshot's implementation. This gives users the
flexibility to choose between the two. If `adjust_lr_fn` is not specified, the default is "original".
For further details regarding the algorithm we refer to `Muon: An optimizer for hidden layers in neural networks`_
and `Muon is Scalable for LLM Training`_.
"""
+ rf"""
Args:
{_params_doc}. Note that Muon is an optimizer for 2D parameters of neural network hidden layers. Other
parameters, such as bias, and embedding, should be optimized by a standard method such as AdamW.
lr (float, Tensor, optional): learning rate (default: 1e-3).
weight_decay (float, optional): weight decay (L2 penalty). (default: 0.1)
momentum (float, optional): momentum factor (default: 0.95)
nesterov (bool, optional): enables Nesterov momentum. Only applicable
when momentum is non-zero
ns_coefficients (tuple of three floats, optional): coefficients \(a,b,c\) for the
NewtonSchulz orthogonalization polynomial (default: ({DEFAULT_A}, {DEFAULT_B}, {DEFAULT_C}))
eps (float, optional): term added to the denominator for numerical stability. (default: {EPS})
ns_steps (int, optional): number of NewtonSchulz iteration steps. (default: {DEFAULT_NS_STEPS})
adjust_lr_fn (str, optional): function to adjust learning rate. One of "original" and "match_rms_adamw".
If not specified, we will default to use "original". (default: None)
.. _Muon\: An optimizer for hidden layers in neural networks:
https://kellerjordan.github.io/posts/muon/
.. _Muon is Scalable for LLM Training:
https://arxiv.org/pdf/2502.16982
"""
)
def _single_tensor_muon(
params: list[Tensor],
grads: list[Tensor],
muon_momentum_bufs: list[Tensor],
*,
lr: float,
weight_decay: float,
momentum: float,
nesterov: bool,
ns_coefficients: tuple[float, float, float],
ns_steps: int,
eps: float,
adjust_lr_fn: Optional[str],
has_complex: bool,
) -> None:
lr = _to_scalar(lr)
if has_complex:
raise ValueError("Complex parameters are not supported")
for i, param in enumerate(params):
grad = grads[i]
if grad.ndim != 2:
raise ValueError("Param gradient must be a 2D matrix")
buf = muon_momentum_bufs[i]
buf.lerp_(grad, 1 - momentum)
update = grad.lerp(buf, momentum) if nesterov else buf
update = _zeropower_via_newtonschulz(update, ns_coefficients, ns_steps, eps)
adjusted_lr = _adjust_lr(lr, adjust_lr_fn, param.shape)
param.mul_(1 - lr * weight_decay)
param.add_(update, alpha=-adjusted_lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_muon)
def muon(
params: list[Tensor],
grads: list[Tensor],
muon_momentum_bufs: list[Tensor],
*,
foreach: Optional[bool] = None,
lr: float,
weight_decay: float,
momentum: float,
nesterov: bool,
ns_coefficients: tuple[float, float, float],
ns_steps: int,
eps: float,
adjust_lr_fn: Optional[str],
has_complex: bool,
):
r"""Functional API that performs Muon algorithm computation.
See :class:`~torch.optim.Muon` for details.
"""
if foreach is not None and foreach:
raise RuntimeError("Foreach is not supported for Muon yet")
func = _single_tensor_muon
func(
params,
grads,
muon_momentum_bufs,
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
nesterov=nesterov,
ns_coefficients=ns_coefficients,
ns_steps=ns_steps,
eps=eps,
adjust_lr_fn=adjust_lr_fn,
has_complex=has_complex,
)


@ -20,6 +20,7 @@ from torch.optim import (
AdamW,
ASGD,
LBFGS,
Muon,
NAdam,
Optimizer,
RAdam,
@ -245,8 +246,9 @@ class optims(_TestParametrizer):
# Helper function for generating error inputs for all optimizers, used below.
def get_error_inputs_for_all_optims(device, dtype):
if _get_device_type(device) == "cpu":
sample_param = Parameter(torch.randn(1, device=device, dtype=dtype))
sample_param2 = Parameter(torch.randn(1, device=device, dtype=dtype))
# Creating 2D parameters for compatibility with Muon.
sample_param = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
sample_param2 = Parameter(torch.randn(1, 1, device=device, dtype=dtype))
return [
ErrorOptimizerInput(
OptimizerInput(
@ -833,6 +835,81 @@ def optim_error_inputs_func_lbfgs(device, dtype):
return error_inputs
def optim_inputs_func_muon(device, dtype=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"lr": torch.tensor(0.001)}, desc="Tensor lr"
),
OptimizerInput(
params=None,
kwargs={"weight_decay": 0.2},
desc="non-default weight_decay",
),
OptimizerInput(
params=None,
kwargs={"momentum": 0.8},
desc="non-default momentum",
),
OptimizerInput(
params=None,
kwargs={"ns_steps": 6},
desc="passing alternative ns_steps",
),
OptimizerInput(
params=None,
kwargs={
"ns_coefficients": (3.4, -4.7, 2.0),
},
desc="passing alternative ns_coefficients",
),
]
def optim_error_inputs_func_muon(device, dtype):
error_inputs = get_error_inputs_for_all_optims(device, dtype)
complex_param = torch.rand(2, 3, device=device, dtype=torch.complex64)
complex_param.grad = torch.rand_like(complex_param)
non_2d_param = torch.rand(2, 3, 4, device=device, dtype=dtype)
non_2d_param.grad = torch.rand_like(non_2d_param)
param = torch.rand(2, 3, device=device, dtype=dtype)
param.grad = torch.rand_like(param)
error_inputs += [
ErrorOptimizerInput(
OptimizerInput(
params=[non_2d_param],
kwargs=dict(),
desc="only support 2D parameters",
),
error_type=ValueError,
error_regex="Muon only supports 2D parameters",
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
),
ErrorOptimizerInput(
OptimizerInput(
params=[param],
kwargs={"adjust_lr_fn": "arbitrary"},
desc="only support `original` and `match_rms_adamw`",
),
error_type=ValueError,
error_regex="Adjust learning rate function arbitrary is not supported",
error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
),
ErrorOptimizerInput(
OptimizerInput(
params=[complex_param],
kwargs=dict(),
desc="does not support complex parameters",
),
error_type=RuntimeError,
error_regex="Muon does not support complex parameters",
error_on=OptimizerErrorEnum.STEP_ERROR,
),
]
return error_inputs
def optim_inputs_func_nadam(device, dtype=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
@ -1869,6 +1946,35 @@ optim_db: list[OptimizerInfo] = [
),
),
),
OptimizerInfo(
Muon,
optim_inputs_func=optim_inputs_func_muon,
optim_error_inputs_func=optim_error_inputs_func_muon,
supported_impls=(),
not_og_supported_flags=(),
supports_complex=False,
skips=(
# Note on tolerances:
# test_correctness_Muon_use_closure_True_cuda_float32
# Mismatched elements: 2 / 100 (2.0%)
# Greatest absolute difference: 0.0006124898791313171 at index (2, 1) (up to 0.0002 allowed)
# Greatest relative difference: 0.026825083419680595 at index (2, 6) (up to 0.01 allowed)
# This is due compile uses addmm for matmul in the orthogonalization function,
# creating a small numerical difference compared to the plain matmul op used in eager.
DecorateInfo(
toleranceOverride(
{
torch.float: tol(
rtol=0.08,
atol=0.001,
),
}
),
"CompiledOptimizerParityTests",
"test_correctness",
),
),
),
OptimizerInfo(
NAdam,
optim_inputs_func=optim_inputs_func_nadam,