1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.
Fixes #163103
LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
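For illustration, a minimal reproduction of the hazard (hypothetical values; relies on optimizers such as SGD accepting a tensor `lr`):
```python
import torch

lr = torch.tensor(0.1)
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=lr)

# The list-comprehension pattern above stores references to the same tensor.
last_lr = [group["lr"] for group in opt.param_groups]

# A scheduler-style in-place update to the group's lr...
opt.param_groups[0]["lr"].fill_(0.01)

# ...also changes the "saved" value, because both names alias one tensor.
print(last_lr[0])  # tensor(0.0100)
```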
This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
    """Create a list containing group[key] for each optimizer param_group.

    Prevents aliasing when group[key] could be a Tensor.
    Raises a KeyError when group[key] does not exist.
    """
    return [
        group[key].clone() if isinstance(group[key], Tensor) else group[key]
        for group in optimizer.param_groups
    ]
```
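A hypothetical usage sketch (assuming the helper above is in scope): the snapshot survives later in-place updates to a tensor lr.
```python
import torch

lr = torch.tensor(0.1)
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=lr)

snapshot = _param_groups_val_list(opt, "lr")  # clones the tensor lr
opt.param_groups[0]["lr"].fill_(0.01)         # later in-place update
print(snapshot[0])                            # tensor(0.1000) -- unchanged
```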
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Fixes #163105
Note that the new `SWALR.load_state_dict` is **not backwards compatible**:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)
    self._set_anneal_func(self._anneal_strategy)
```
If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    anneal_func = state_dict.pop("anneal_func", None)
    strategy = state_dict.get("_anneal_strategy")
    self.__dict__.update(state_dict)
    if anneal_func is not None:
        state_dict["anneal_func"] = anneal_func
        if strategy is None:
            if anneal_func == self._linear_anneal:
                strategy = "linear"
            elif anneal_func == self._cosine_anneal:
                strategy = "cos"
    if strategy is None:
        strategy = getattr(self, "_anneal_strategy", "cos")
    self._set_anneal_func(strategy)
```
But given that loading an `SWALR` state_dict before this PR would have raised an error anyway, this seems acceptable. A GitHub/Google search for `SWALR.load_state_dict` turned up no results. Happy to change this if needed, or to add a warning just in case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122
Approved by: https://github.com/janeyx99
Prevents edge cases in SequentialLR and ReduceLROnPlateau which could corrupt learning rates or trigger recompilation.
Supersedes #162360. Fixes #162359. Fixes #163093.
While putting #162360 together, I noticed the class of issue I was fixing (i.e. unintended aliasing in lr_schedulers when using Tensor lrs) appeared in several other places. @janeyx99 suggested I put together a follow-up PR.
There are several bugs resembling the one fixed in #162360. I added a helper to fix these:
```python
def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor):
    """Set param_group[key] to val without aliasing or assignment when they're both tensors.

    Raises a KeyError if param_group[key] does not exist.
    """
    if isinstance(param_group[key], Tensor):
        param_group[key].fill_(_to_scalar(val))
    else:
        param_group[key] = val
```
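For illustration, a hedged sketch of the bug class this helper addresses (dictionary contents are hypothetical):
```python
import torch

group = {"lr": torch.tensor(0.1), "initial_lr": torch.tensor(0.1)}

# Buggy pattern: plain assignment makes "lr" alias "initial_lr"...
group["lr"] = group["initial_lr"]
# ...so a later in-place lr update silently corrupts initial_lr.
group["lr"].fill_(0.05)
print(group["initial_lr"])  # tensor(0.0500)
```
With the helper, the existing `"lr"` tensor is filled in place (or plainly reassigned when it is a float), so `initial_lr` is untouched and the `lr` tensor keeps its identity, which also avoids the recompilation issue mentioned above.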
And applied it to fix bugs in `SequentialLR.__init__` and `LRScheduler._update_lr`. I also applied it in `CyclicLR.__init__`, which used an equivalent pattern, and in `CosineAnnealingWarmRestarts.step`, which *should* have had a similar issue:
```python
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
    param_group["lr"] = lr
```
But it did not, because `get_lr()` actually returns tensors when using a tensor lr (despite its `list[float]` return type annotation). Relying on this propagation seems fragile, so I conservatively applied the helper here as well. I'll fix the type annotations and several related issues in follow-up PRs built on top of this one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163098
Approved by: https://github.com/janeyx99
Summary:
The check `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is hit on every update.
It is only meant to determine whether the AveragedModel copy has been initialized.
This diff introduces a CPU-based variable for that purpose.
When loading from a checkpoint, we also make sure this flag is refreshed.
After this fix, each `update_parameters` call drops from 333 ms to 6 ms (a 98% reduction).
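A minimal sketch of the idea (not the actual diff; names like `_initialized` are hypothetical): replace the per-update tensor comparison with a Python-level flag that is refreshed when loading a checkpoint.
```python
import copy

import torch
from torch import nn


class AveragedModelSketch(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.module = copy.deepcopy(model)
        self.register_buffer("n_averaged", torch.tensor(0, dtype=torch.long))
        # CPU-side flag: checking it never forces a device sync, unlike
        # bool(self.n_averaged == 0) when n_averaged lives on the GPU.
        self._initialized = False

    def update_parameters(self, model: nn.Module) -> None:
        for p_avg, p in zip(self.module.parameters(), model.parameters()):
            if not self._initialized:
                p_avg.detach().copy_(p.detach())
            else:
                # running average: avg += (p - avg) / (n + 1)
                p_avg.detach().add_((p.detach() - p_avg.detach()) / (self.n_averaged + 1))
        self.n_averaged += 1
        self._initialized = True

    def load_state_dict(self, state_dict, strict: bool = True):
        result = super().load_state_dict(state_dict, strict)
        # Refresh the flag so a restored checkpoint counts as initialized.
        self._initialized = bool(self.n_averaged > 0)
        return result
```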
Test Plan:
contbuild & OSS CI
Test plan from GitHub:
CI
Rollback Plan:
Differential Revision: D78074709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158017
Approved by: https://github.com/janeyx99
A performance optimization: use `torch.addmm`, which fuses `matrix multiply + scale + add` into one op.
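For context, a hedged illustration of the fusion (shapes and scalars are arbitrary): `torch.addmm(input, mat1, mat2, beta=..., alpha=...)` computes `beta * input + alpha * (mat1 @ mat2)` in a single op.
```python
import torch

A = torch.randn(128, 64)
B = torch.randn(64, 32)
C = torch.randn(128, 32)
alpha, beta = 0.1, 0.9

# Unfused: separate matmul, scale, and add, each with its own intermediate.
out_unfused = beta * C + alpha * (A @ B)

# Fused: one op, fewer intermediates (and fewer roundings at low precision).
out_fused = torch.addmm(C, A, B, beta=beta, alpha=alpha)

torch.testing.assert_close(out_unfused, out_fused)
```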
**Benchmark**
In training a QWEN-like 0.5B model, we observed the average `optimizer.step()` latency drop from ~44.5 ms (matmul) to ~27.4 ms (addmm), a **1.62×** speedup.
matmul
<img width="1403" height="600" alt="Screenshot 2025-08-24 at 3 15 37 PM" src="https://github.com/user-attachments/assets/a77a68d4-da3c-473a-97f0-e6ef0a3b46d9" />
addmm
<img width="1426" height="602" alt="Screenshot 2025-08-24 at 3 13 42 PM" src="https://github.com/user-attachments/assets/e493af36-44d3-4026-9f7c-fd0f9cdbc7e5" />
**Testing**
End-to-end training:
We used a training script that pre-trains a QWEN-like model on the `openwebtext-100k` dataset. We trained for one epoch, and the resulting loss curves show consistency between normal matmul and addmm.
<img width="1035" height="434" alt="Screenshot 2025-08-24 at 2 56 21 PM" src="https://github.com/user-attachments/assets/b96b13e3-0a01-4908-853c-d917b41f3d75" />
Unit test:
```python
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon  # import path assumed for the new optimizer

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()
lr = 1e-3
wd = 0.1
momentum = 0.95
nesterov = True  # assumed value; not defined in the original snippet

opt_ref_muon = Muon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
)
opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
    nesterov=nesterov,
    adjust_lr_fn="original",
    use_addmm=True,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```
The test shows a small numerical difference, but this is expected at bf16 precision:
```
Mismatched elements: 96 / 100 (96.0%)
Greatest absolute difference: 8.985400199890137e-05 at index (1, 9) (up to 1e-06 allowed)
Greatest relative difference: 0.007370449136942625 at index (0, 6) (up to 1e-05 allowed)
```
~~Introduced a flag that allows users to opt in, as there are numerical differences relative to the original implementation.~~
Update: since `addmm` fuses the math ops, there are fewer intermediate roundings, so it is actually more numerically accurate than the original form. Based on this, we opt to make `addmm` the default and only option.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161379
Approved by: https://github.com/janeyx99
A single-device version of Muon. The algorithm follows Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/) and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.
This implementation maintains a minimalist API and is consistent with other optimizer conventions. The PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.
**Usage**
```python
model = MyModelForCausalLM
# filter out your params manually
muon_params = [...]
adamw_params = [...]
muon = Muon(
    params=muon_params,
    lr=lr,
    weight_decay=wd,
)
adamw = AdamW(
    params=adamw_params,
    lr=lr,
    weight_decay=wd,
)

# in training loop
loss = model(input)
loss.backward()
muon.step()
adamw.step()
muon.zero_grad()
adamw.zero_grad()
```
~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~
~~`AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]`~~
~~`MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]`~~
As discussed with the team and in the comments, we prefer a simpler and cleaner interface, so we removed the callback interface and canonicalized the original Newton-Schulz (NS) algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.
By default, we use 5-step Newton-Schulz with the coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use the LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts the learning rate from AdamW.
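For reference, a sketch of that 5-step Newton-Schulz orthogonalization along the lines of Keller Jordan's blogpost (coefficients from the post; this is an illustration, not the exact code merged here):
```python
import torch
from torch import Tensor


def newton_schulz_msign(G: Tensor, ns_steps: int = 5, eps: float = 1e-7) -> Tensor:
    """Approximate orthogonalization (msign) of a 2D tensor G."""
    a, b, c = 3.4445, -4.7750, 2.0315  # quintic coefficients from the blogpost
    X = G.to(torch.bfloat16)
    X = X / (X.norm() + eps)  # normalize so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.mT
    for _ in range(ns_steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
```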
**Testing**
~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~
As discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch. We also updated the interface so we don't need to test the FQN-based filtering. Muon is covered by the existing `test_optim.py` unit tests.
2. End-to-end test: we added a training script that pre-trains a QWEN-like model on the `openwebtext-100k` dataset. We trained for one epoch, and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />
**Numerics**
We evaluated our implementation against existing implementations to confirm numerical consistency.
As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `AdamW`, making Muon a drop-in swap.
As expected, the numerical difference comes mainly from `adjust_lr`, with a max of ~5% relative difference in the example unit-test setup below.
```python
import copy

import torch
from torch.nn import Linear, MSELoss
from torch.optim import Muon  # import path assumed for the new optimizer

# KellySingleDeviceMuon (the reference implementation used as the baseline)
# is assumed to be importable; it is not defined in this snippet.

# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()
lr = 1e-3
wd = 0.1
momentum = 0.95

opt_ref_muon = KellySingleDeviceMuon(
    params=model0.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
)
opt_exp_muon = Muon(
    params=model1.parameters(),
    lr=lr,
    weight_decay=wd,
    momentum=momentum,
)

out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()

out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()

for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
    torch.testing.assert_close(p_ref, p_exp)
```
As explained above, including this `adjust_lr` is preferable. This is validated by e2e training runs on a Qwen-2-like 0.5B model, where the curves show that training with `adjust_lr` converges more effectively than without it.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />
**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min
Muon runs ~20 s slower than AdamW; assuming no other changes, that makes Muon about *2.5%* slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />
Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />
**Note**
We restrict the implementation to accept only 2D parameters.
An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.
Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
**Next Steps**
1. Add `MuP`
2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x - 1.48x speedup on small to large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
Improves typing so that the optimizer subclasses (all of which override `step()`) do not have their type signatures erased when this decorator is used. Now kwarg values and return types will propagate.
This complements @tsunghsienlee's PR #153367, as the type signature of `step()` was being erased on all the optimizer subclasses by this untyped decorator.
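A minimal sketch of the typing pattern (the decorator name and body here are placeholders, not the actual PyTorch decorator): `ParamSpec` plus a `TypeVar` let the wrapper expose the wrapped `step()`'s exact parameters and return type.
```python
import functools
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def _typed_step_decorator(func: Callable[P, R]) -> Callable[P, R]:
    # Annotating the decorator as Callable[P, R] -> Callable[P, R] keeps each
    # subclass's step() parameter and return types visible to type checkers,
    # instead of collapsing them to an untyped callable.
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # the real decorator would, e.g., toggle grad mode around step()
        return func(*args, **kwargs)

    return wrapper
```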
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153374
Approved by: https://github.com/janeyx99, https://github.com/tsunghsienlee
## Summary
This PR updates the docstring for `CosineAnnealingLR` to accurately reflect its recursive learning rate schedule. The previous docstring displayed only the SGDR closed-form expression, which doesn't match the actual recursive implementation in code.
Changes:
- Added the recursive update formula used in `get_lr()`
- Retained the original closed-form SGDR expression for reference
- Clarified that warm restarts are not implemented in this scheduler
This addresses confusion raised in issue #152081.
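For reference, the two expressions in question (reconstructed here; notation in the final docstring may differ slightly) are the recursive update applied by `get_lr()`:
```math
\eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot \frac{1 + \cos\left(\frac{(T_{cur}+1)\pi}{T_{\max}}\right)}{1 + \cos\left(\frac{T_{cur}\,\pi}{T_{\max}}\right)}
```
and the closed-form SGDR expression retained for comparison:
```math
\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{T_{cur}\,\pi}{T_{\max}}\right)\right)
```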
## Related issue
[#152081](https://github.com/pytorch/pytorch/issues/152081)
## Testing
Doc-only change. Ran pre-commit to verify formatting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152936
Approved by: https://github.com/janeyx99