A single-device version of Muon. Algorithm refers Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/), and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.
This implementation maintains a minimalist API and is consistent with other optimizer conventions. PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.
**Usage**
```python
model = MyModelForCausalLM
# filter out your params manually
muon_params = [...]
adamw_params = [...]
muon = Muon(
params = muon_params
lr=lr,
wd=wd,
)
adamw = AdamW(
params = adamw_params
lr=lr,
wd=wd,
)
# in training loop
loss = model(input)
loss.backward()
muon.step()
adamw.step()
muon.zero_grad()
adamw.zero_grad()
```
~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~
```python
~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~
```
As discussed with team and in comment, we prefer to make the interface simpler and cleaner, thus we removed the callback interface, and canonicalize the original NS algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.
By default, we use 5-step Newton-Schulz, with coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts learning rate from AdamW.
**Testing**
~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~
As discussed, in order not to complicate the codebase, we prefer not to include reference implementation into PyTorch. We also updated the interface so we don't need to test the FQN based filtering. Muon is covered by the existing `test_optim.py` unit test.
2. End-to-end test: we added a training script that pre-trains a QWEN-like model on `openwebtext-100k` dataset. We trained for one epoch and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />
**Numerics**
We evaluate our implementation with existing implementation to confirm numerical consistency.
As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `adamW`, making Muon a drop-in swap.
As expected, the numerics difference mainly comes from `adjust_lr`, a max of ~5% relative diff in an example unit test setup below.
```python
# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()
lr = 1e-3
wd = 0.1
momentum = 0.95
opt_ref_muon = KellySingleDeviceMuon(
params=model0.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
opt_exp_muon = Muon(
params=model1.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()
out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()
for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
torch.testing.assert_close(p_ref, p_exp)
```
As explained above, including this `adjust_lr` is preferable. This is validated by an e2e training runs on training a qwen-2-like 0.5b model, where the curves show that training with `adjust_lr` converges more effectively than without.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />
**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min
Muon runs ~20s slower compared to AdamW. Assuming no other changes, Muon is *2.5%* slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />
Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />
**Note**
We restrict the implementation to accept only 2D parameters.
An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.
Since Muon is designed to operate orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
**Next Steps**
1. Add `MuP`
2. Open-source optimized triton kernel for symmetric matmul. A preliminary benchmark found 1.23x - 1.48x speedup on small - large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.
****
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
Fixes an error message in test/test_optim.py
Current behavior: If running the test with Adagrad, the error message reads: "SGD does not currently support capturable".
Fix: The error message now says correctly: "Adagrad does not currently support capturable".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153076
Approved by: https://github.com/janeyx99
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements
> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
> len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
> len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
> f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
Thanks @clee2000 for bringing the memleak to my attention: https://github.com/pytorch/pytorch/actions/runs/12549765082/job/34996244798.
This memleak in the test was caused by the differentiable flavors. Because we had param.clone() and param persisted outside the for loop, the autograd graph would continue growing for each optimizer.step instead of being deleted after the optim input was used up.
To clarify, I had still expected (and still do expect) the test to fully clean everything up once the test is over, but I didn't get the chance to look into why that's not the case. This change would preliminarily unblock this particular test from failing the memleak CI.
Use detach instead of clone, which is...cheaper anyway :D since a detach I've learned from @soulitzer is a view with requires_grad=False
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144154
Approved by: https://github.com/clee2000, https://github.com/Skylion007, https://github.com/huydhn, https://github.com/albanD
Fixes#104899
Refactors AdamW into Adam by making AdamW a subclass of Adam. Additionally adds a test to assert that the added parameter `decoupled_weight_decay` is True in AdamW and also updates test_defaults_changed_to_foreach to account for the differences in module location for AdamW.
Heavily heavily inspired by #118857 by @tfsingh
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143710
Approved by: https://github.com/janeyx99
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**
(also mentioned in here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552)
## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their id.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict =
{
'state': {
0: {'momentum_buffer': tensor(...), ...},
1: {'momentum_buffer': tensor(...), ...},
},
'param_groups': [
{
'lr': 0.01,
'weight_decay': 0,
...
'params': [0,1]
'param_names' ['layer.weight', 'layer.bias'] (optional)
}
]
}
```
Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored.
## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.
## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.
#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.
## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.
## Solution Example:
A suggested solution to the problem mentioned in #1489, for the same parameters but in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict :
```python
def adapt_state_dict_ids(optimizer, state_dict):
# assuming a single param group.
current_state_group = optimizer.state_dict()['param_groups'][0]
loaded_state_group = state_dict['param_groups'][0]
# same number of params, same names, only different ordering
current_state_name_to_id_mapping = {} # mapping -- param_name: id
for i, name in enumerate(current_state_group['param_names']):
current_state_name_to_id_mapping[name] = current_state_group['params'][i]
# changing the ids of the loaded state dict to match the order of the given state dict.
for i, name in enumerate(current_state_group['param_names']):
loaded_state_group['params'][i] = current_state_name_to_id_mapping[name]
return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers are required to be initiated with `named_parameters()` to have the 'param_names' key in the dict.
### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This PR adds the foreach impl for Adafactor knowing that there are many ways to improve its runtime perf today (by adding more foreach support). After this PR:
- we have a foreach flag for Adafactor
- It is NOT the default. Why not? It is only slightly faster + uses O(n) more memory where n is the number of params in your max param group. People tend to use Adafactor for memory efficiency.
Next steps:
- make torch.compile possible on it
- make it faster (by adding more foreach apis)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132336
Approved by: https://github.com/albanD
ghstack dependencies: #133360
#109581
At this point, the vanilla implementation (the default) is good.
Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor
Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```
The three results respectively for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
[ 0.0762539, 0.5377805],
[ 0.2459473, 0.4662207]])
# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
[ 0.0762507, 0.5377818],
[ 0.2459457, 0.4662213]])
# keras
array([[ 0.38076326, -0.39121315],
[ 0.0762547 , 0.5377859 ],
[ 0.24594972, 0.46622536]], dtype=float32)
```
This gives me confidence to move forward in speeding up the implementation now that a baseline has been established. If you're curious about differences:
* keras assigns step_size (rho_t in their code) to `min(lr, 1 / sqrt(step)` whereas the OG impl uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01.
* We differ from the pytorch-optimizers default in that our default will not track momentum (thus `beta1=0`) and we do not apply parameter scaling.
<details>
Keras collab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing
My script repro:
```
import torch
from pytorch_optimizer import AdaFactor
torch.set_printoptions(precision=7)
weight = torch.tensor([[ 0.37697506, -0.39500135],
[ 0.07246649, 0.53399765],
[ 0.24216151, 0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)
weight.grad = torch.tensor([[-0.5940447, -0.7743838],
[-0.5940447, -0.7743838],
[-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974, 1.5422692], dtype=torch.float32)
weight_c = weight.clone()
weight_c.grad = weight.grad.clone()
optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)
optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```
<details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Fixes some files in #123062
Run lintrunner on files:
test/test_nnapi.py,
test/test_numba_integration.py,
test/test_numpy_interop.py,
test/test_openmp.py,
test/test_optim.py
```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126845
Approved by: https://github.com/ezyang
> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't be changed. Therefore, it can be simplified using Python list multiplication, which only constructs one tensor.
- [X] Ill assumption that every param will have the same step.
- [x] DIfferent implementation between `foreach=Ture` and `foreach=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
Enables LRScheduler to handle tensor LRs.
Note on test changes:
For the test modifications I just removed itertools.product and created two loops. This allows us to create a new set of optim_inputs on each iteration to prevent mutations on the tensor LR carrying over across iterations. Nothing else in those tests was modified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123753
Approved by: https://github.com/janeyx99
ghstack dependencies: #123751, #123752
- Original `test_grad_scaling_autocast_fused_optimizers` does not work since there is no "fused" in `optim_inputs`
- We should use different `grad_scaler`, they should not share 1 `scale`, there is no issue exposed here because the default `_growth_interval` is 2000 so it will not growth and there is also no inf is found so it will not reduced. The one in `test_cuda.py` should also have this issue,
- I set a manual seed to reproduce purpose if there is any numerical failure
- I use Tensor tracker here because we failed this UT in dynamo case, the cpp generated code are not exactly same with fused/non fused kernel.
- I make it check both `cuda` and `cpu`.
- I find some SGD numerical issue with `clang`, and fixed it by using `fmadd` instead of `add/mul` in fused sgd veckernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124904
Approved by: https://github.com/jgong5, https://github.com/janeyx99