A single-device version of Muon. Algorithm refers Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/), and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.
This implementation maintains a minimalist API and is consistent with other optimizer conventions. PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.
**Usage**
```python
model = MyModelForCausalLM
# filter out your params manually
muon_params = [...]
adamw_params = [...]
muon = Muon(
params = muon_params
lr=lr,
wd=wd,
)
adamw = AdamW(
params = adamw_params
lr=lr,
wd=wd,
)
# in training loop
loss = model(input)
loss.backward()
muon.step()
adamw.step()
muon.zero_grad()
adamw.zero_grad()
```
~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~
```python
~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~
```
As discussed with team and in comment, we prefer to make the interface simpler and cleaner, thus we removed the callback interface, and canonicalize the original NS algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.
By default, we use 5-step Newton-Schulz, with coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts learning rate from AdamW.
**Testing**
~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~
As discussed, in order not to complicate the codebase, we prefer not to include reference implementation into PyTorch. We also updated the interface so we don't need to test the FQN based filtering. Muon is covered by the existing `test_optim.py` unit test.
2. End-to-end test: we added a training script that pre-trains a QWEN-like model on `openwebtext-100k` dataset. We trained for one epoch and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />
**Numerics**
We evaluate our implementation with existing implementation to confirm numerical consistency.
As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `adamW`, making Muon a drop-in swap.
As expected, the numerics difference mainly comes from `adjust_lr`, a max of ~5% relative diff in an example unit test setup below.
```python
# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()
lr = 1e-3
wd = 0.1
momentum = 0.95
opt_ref_muon = KellySingleDeviceMuon(
params=model0.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
opt_exp_muon = Muon(
params=model1.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()
out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()
for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
torch.testing.assert_close(p_ref, p_exp)
```
As explained above, including this `adjust_lr` is preferable. This is validated by an e2e training runs on training a qwen-2-like 0.5b model, where the curves show that training with `adjust_lr` converges more effectively than without.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />
**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min
Muon runs ~20s slower compared to AdamW. Assuming no other changes, Muon is *2.5%* slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />
Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />
**Note**
We restrict the implementation to accept only 2D parameters.
An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.
Since Muon is designed to operate orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
**Next Steps**
1. Add `MuP`
2. Open-source optimized triton kernel for symmetric matmul. A preliminary benchmark found 1.23x - 1.48x speedup on small - large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.
****
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
Our three main users are OK with this, with two of them (foreach_map,
invoke_quant) prefering it like this.
I was originally worried about BC issues (this now means you cannot add
any positional args) but I think that's not a concern -- one can always
add kwonly args.
Test Plan
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146730
Approved by: https://github.com/ydwu4, https://github.com/mlazos
# Motivation
Fix https://github.com/pytorch/pytorch/issues/138577.
# Solution
1. All UTs in `test/inductor/test_compiled_optimizers.py` are fixed by https://github.com/pytorch/pytorch/pull/134170
2. UT in `test/inductor/test_pattern_matcher.py` is introduced by https://github.com/pytorch/pytorch/pull/138089, we will skip this UT due to the unsupported feature `max_autotune_gemm_backends:Triton`.
3. We have a new impl related to `histc`, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
4. We support `avg_pool3d` for `fp16` data type, so we remove the expected failure from `test/inductor/test_torchinductor_opinfo.py`
5. CUDA-bias code is introduced by https://github.com/pytorch/pytorch/issues/138472, we just generalize it to `GPU_TYPE`.
# Additional Context
> Why update torch-xpu-ops commit pin here?
We have to update commit pin to avoid the build failure raised by the code change [C10_UNUSED](https://github.com/pytorch/pytorch/pull/138364).
> What does the feature of torch-xpu-ops update?
1. Add some foreach ops, like `unary ops` and `foreach_clamp_max` etc;
2. Add some maxpool ops forward and backward, like `averge_pool3d` and `max_pool3d`
3. Add some other ops, like `log_normal_`, `index_copy`, and `mode` etc;
4. fix build failure related to `C10_UNUSED`;
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138548
Approved by: https://github.com/malfet, https://github.com/EikanWang
Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps. After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports.
Test Plan:
```
fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS
```
Before:
```
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib. Some things to try:
```
Differential Revision: D62049222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614
Approved by: https://github.com/oulgen, https://github.com/laithsakka
Resubmit of #128979
`WeakDep`s force readers to have completed before a mutation overwrites the
buffer, but we want to allow fusions to occur for inplace mutations where the
same index is read and written.
Currently this is achieved by:
1. Identifying the buffers used by the mutating op in its `dep_closure`
2. Not creating `WeakDep`s for buffers in the `dep_closure`
3. Fixing up any bad fusions that might occur by an extra check in `can_fuse_vertical`
So we are first over-agressive in removing `WeakDep`, then add an ad-hoc fixup.
This PR instead emits all `WeakDep`s and adds a `fusable_weak_dep` check to
`can_fuse_vertical` which selectively allows inplace operation to fuse.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130835
Approved by: https://github.com/lezcano
In https://www.internalfb.com/intern/sevmanager/view/s/429861/, a downstream consuming buffer `buf486_buf526` had two read dependencies; `buf373` and `buf394`, both of which were at separate indices of the upstream foreach op. `buf486_buf526` was fused into `buf373` because in the usual fused case, this is completely fine if all dependencies are met in the upstream fused buffer. However in the foreach case and this case specifically it is possible for foreach ops to be partitioned if there are many arguments in order to stay under CUDA driver arg limits. As a result, this large foreach op was split into two, and the latter had `buf394` in its node schedule for allocation, while the earlier split did not, even though `buf486_buf526` uses the `buf394`, as a result we would hit the unbound local error.
@eellison provided this repro to help debug the issue (https://www.internalfb.com/phabricator/paste/view/P1453035092)
To fix this, we no longer return a valid producer subnode if there are multiple producer subnodes for a downstream consuming op. In short we should not fuse if there are dependencies on multiple foreach subkernels because 1) their execution order is non-deterministic and 2) (this issue) we may not properly handle dependencies in the presence of foreach partitioning.
Co-authored-by: David Berard <dberard@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130046
Approved by: https://github.com/eellison
`WeakDep`s force readers to have completed before a mutation overwrites the
buffer, but we want to allow fusions to occur for inplace mutations where the
same index is read and written.
Currently this is achieved by:
1. Identifying the buffers used by the mutating op in its `dep_closure`
2. Not creating `WeakDep`s for buffers in the `dep_closure`
3. Fixing up any bad fusions that might occur by an extra check in `can_fuse_vertical`
So we are first over-agressive in removing `WeakDep`, then add an ad-hoc fixup.
This PR instead emits all `WeakDep`s and adds a `fusable_weak_dep` check to
`can_fuse_vertical` which selectively allows inplace operation to fuse.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128979
Approved by: https://github.com/lezcano
ghstack dependencies: #129082, #129083
This PR is meant to address issue #123451, more specifically, the ```test_graph_optims``` and ```test_graph_scaling_fused_optimizers``` functions in ```test_cuda.py``` have been updated so that they now use the new OptimizerInfo infrastructure.
Lintrunner passed:
```
$ lintrunner test/test_cuda.py
ok No lint issues.
```
Tests passed:
```
>python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s
OK (skipped=9)
>python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s
OK (skipped=3)
```
Both the functions have been moved to the newly created TestCase class ```TestCudaOptims```. The test is mostly the same except the ```@optims``` decorator is used at the top of the function to implicitly call the function using each of the optimizers mentioned in the decorator instead of explicitly using a for loop to iterate through each of the optimizers.
I was unable to use the ```_get_optim_inputs_including_global_cliquey_kwargs``` to get all kwargs for each of the optimizers since some of the kwargs that are used in the original ```test_graph_optims``` function are not being returned by the new OptimizerInfo infrastructure, more specifically, for the ```torch.optim.rmsprop.RMSprop``` optimizer, the following kwargs are not returned whenever ```_get_optim_inputs_including_global_cliquey_kwargs``` is called:
```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{ 'foreach': True, 'maximize': True, 'weight_decay': 0}
```
I ran into the same issue for ```test_graph_scaling_fused_optimizers```, for the ```torch.optim.adamw.AdamW``` optimizer, whenever ```optim_info.optim_inputs_func(device=device)``` was called, the following kwarg was not returned:
```
{'amsgrad': True}
```
Due to this issue, I resorted to using a dictionary to store the kwargs for each of the optimizers, I am aware that this is less than ideal. I was wondering whether I should use the OptimizerInfo infrastructure to get all the kwargs regardless of the fact that it lacks some kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
Removed a bunch of skips, I also updated test_forloop_goes_right_direction to *not* use the closure when dynamo is tracing. The reason for this is that testing the disabled optimizer doesn't actually test anything.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123322
Approved by: https://github.com/janeyx99
ghstack dependencies: #123498