A single-device version of Muon. Algorithm refers Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/), and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.
This implementation maintains a minimalist API and is consistent with other optimizer conventions. PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.
**Usage**
```python
model = MyModelForCausalLM
# filter out your params manually
muon_params = [...]
adamw_params = [...]
muon = Muon(
params = muon_params
lr=lr,
wd=wd,
)
adamw = AdamW(
params = adamw_params
lr=lr,
wd=wd,
)
# in training loop
loss = model(input)
loss.backward()
muon.step()
adamw.step()
muon.zero_grad()
adamw.zero_grad()
```
~~**Additional usage**~~
~~Users are also able to pass in self-defined `msign` function for orthogonalization, and learning rate adjustment function. Interface defined below:~~
```python
~~AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]~~
~~MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]~~
```
As discussed with team and in comment, we prefer to make the interface simpler and cleaner, thus we removed the callback interface, and canonicalize the original NS algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.
By default, we use 5-step Newton-Schulz, with coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts learning rate from AdamW.
**Testing**
~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~
As discussed, in order not to complicate the codebase, we prefer not to include reference implementation into PyTorch. We also updated the interface so we don't need to test the FQN based filtering. Muon is covered by the existing `test_optim.py` unit test.
2. End-to-end test: we added a training script that pre-trains a QWEN-like model on `openwebtext-100k` dataset. We trained for one epoch and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />
**Numerics**
We evaluate our implementation with existing implementation to confirm numerical consistency.
As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `adamW`, making Muon a drop-in swap.
As expected, the numerics difference mainly comes from `adjust_lr`, a max of ~5% relative diff in an example unit test setup below.
```python
# dummy model and data
model0 = Linear(10, 10, bias=False)
model1 = copy.deepcopy(model0)
inputs = torch.randn(8, 10)
targets = torch.randn(8, 10)
loss = MSELoss()
lr = 1e-3
wd = 0.1
momentum = 0.95
opt_ref_muon = KellySingleDeviceMuon(
params=model0.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
opt_exp_muon = Muon(
params=model1.parameters(),
lr=lr,
weight_decay=wd,
momentum=momentum,
)
out_ref = model0(inputs)
loss_ref = loss(out_ref, targets)
opt_ref_muon.zero_grad()
loss_ref.backward()
opt_ref_muon.step()
out_exp = model1(inputs)
loss_exp = loss(out_exp, targets)
opt_exp_muon.zero_grad()
loss_exp.backward()
opt_exp_muon.step()
for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
torch.testing.assert_close(p_ref, p_exp)
```
As explained above, including this `adjust_lr` is preferable. This is validated by an e2e training runs on training a qwen-2-like 0.5b model, where the curves show that training with `adjust_lr` converges more effectively than without.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />
**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min
Muon runs ~20s slower compared to AdamW. Assuming no other changes, Muon is *2.5%* slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />
Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />
**Note**
We restrict the implementation to accept only 2D parameters.
An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opt not to go with this approach as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.
Since Muon is designed to operate orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
**Next Steps**
1. Add `MuP`
2. Open-source optimized triton kernel for symmetric matmul. A preliminary benchmark found 1.23x - 1.48x speedup on small - large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.
****
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
This PR fixes a bug in `test_correct_module_names` introduced in #130497. It also addresses post-fix test failures in:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
This PR fixes a bug in `test_correct_module_names` introduced in #130497. It also addresses post-fix test failures in:
* `torch/ao/quantization/__init__.py` - set the correct `__module__` for several public API helpers
* `torch/library.py` - add `register_vmap` to `__all__`
* `torch/nn/attention/flex_attention.py` - make `round_up_to_multiple` private by prepending an underscore
* `torch/storage.py` - introduce `__all__` to avoid `Self` being re-exported as a public API
* `torch/distributed/pipelining/schedules.py` - add `ZeroBubbleAlgorithm` to `__all__`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131386
Approved by: https://github.com/albanD
#109581
At this point, the vanilla implementation (the default) is good.
Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor
Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```
The three results respectively for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
[ 0.0762539, 0.5377805],
[ 0.2459473, 0.4662207]])
# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
[ 0.0762507, 0.5377818],
[ 0.2459457, 0.4662213]])
# keras
array([[ 0.38076326, -0.39121315],
[ 0.0762547 , 0.5377859 ],
[ 0.24594972, 0.46622536]], dtype=float32)
```
This gives me confidence to move forward in speeding up the implementation now that a baseline has been established. If you're curious about differences:
* keras assigns step_size (rho_t in their code) to `min(lr, 1 / sqrt(step)` whereas the OG impl uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01.
* We differ from the pytorch-optimizers default in that our default will not track momentum (thus `beta1=0`) and we do not apply parameter scaling.
<details>
Keras collab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing
My script repro:
```
import torch
from pytorch_optimizer import AdaFactor
torch.set_printoptions(precision=7)
weight = torch.tensor([[ 0.37697506, -0.39500135],
[ 0.07246649, 0.53399765],
[ 0.24216151, 0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)
weight.grad = torch.tensor([[-0.5940447, -0.7743838],
[-0.5940447, -0.7743838],
[-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974, 1.5422692], dtype=torch.float32)
weight_c = weight.clone()
weight_c.grad = weight.grad.clone()
optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)
optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```
<details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
Fixes#112592
1) **File: torch/cuda/random.py**
```
Before:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D208: Docstring is over-indented
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D209: Multi-line docstring closing quotes should be on a separate line
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D414: Section has no content ('Args')
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
18
```
```
After:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
1
```
2) **File: torch/cuda/amp/autocast_mode.py**
```
Before: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
12
```
```
After:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
5
```
3) **File: torch/cuda/amp/grad_scaler.py**
```
Before: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D400: First line should end with a period (not 'g')
/content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`:
D401: First line should be in imperative mood (perhaps 'Update', not 'Updates')
/content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`:
D401: First line should be in imperative mood (perhaps 'Load', not 'Loads')
/content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`:
D105: Missing docstring in magic method
28
```
```
After:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`:
D105: Missing docstring in magic method
5
```
4) **File: torch/optim/_functional.py**
```
Before:
/content/pytorch/torch/optim/_functional.py:1 at module level:
D400: First line should end with a period (not 'e')
1
```
```
After:
0
```
5) **File: torch/optim/__init__.py**
```
Before:
/content/pytorch/torch/optim/__init__.py:1 at module level:
D205: 1 blank line required between summary line and description (found 0)
1
```
```
After:
0
```
6) **File: torch/optim/lbfgs.py**
```
Before:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D400: First line should end with a period (not 'c')
/content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/lbfgs.py:285 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
5
```
```
After:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`:
D107: Missing docstring in __init__
2
```
7)**File: torch/optim/sparse_adam.py**
```
Before: /content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
4
```
```
After:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
3
```
8) **File:torch/optim/adadelta.py**
```
Before:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adadelta.py:82 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
9) **File: torch/optim/adagrad.py**
```
Before:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
/content/pytorch/torch/optim/adagrad.py:100 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`:
D202: No blank lines allowed after function docstring (found 1)
7
```
```
After:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
5
```
10) **File: torch/optim/adam.py**
```
Before:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adam.py:135 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D205: 1 blank line required between summary line and description (found 0)
7
```
```
After:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
11) **File: torch/optim/adamax.py**
```
Before:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamax.py:91 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamax.py:203 in public function `adamax`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
12) **File: torch/optim/adamw.py**
```
Before:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamw.py:153 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamw.py:304 in public function `adamw`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
13) **File: torch/optim/asgd.py**
```
Before:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/asgd.py:107 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/asgd.py:195 in public function `asgd`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
Resolved docstring errors as listed. I initially changed in the main branch of forked repo which caused changes to appear in my PR to other issue. I have fixed that and hope this PR won't have any conflicts.
Kindly review @svekars @jbschlosser.
In case of any other issues please let me know. Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964
Approved by: https://github.com/kit1980
Summary:
Fixes : https://github.com/pytorch/pytorch/issues/24892
In the paper : https://arxiv.org/pdf/1908.03265.pdf Liyuan Liu et al. suggested a new optimization algorithm with an essence of similar to Adam Algorithm.
It has been discussed in the paper that, without warmup heuristic, in the early stage of adaptive optimization / learning algorithms sometimes we can get undesirable large variance which can slow overall convergence process.
Authors proposed the idea of rectification of variance of adaptive learning rate when it is expected to be high.
Differing from the paper, we selected variance tractability cut-off as 5 instead of 4. This adjustment is common practice, and could be found in the code-repository and also tensorflow swift optim library as well :
2f03dd1970/radam/radam.py (L156)f51ee4618d/Sources/TensorFlow/Optimizers/MomentumBased.swift (L638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58968
Reviewed By: vincentqb
Differential Revision: D29310601
Pulled By: iramazanli
fbshipit-source-id: b7bd487f72f1074f266687fd9c0c6be264a748a9
Summary:
Fixes : https://github.com/pytorch/pytorch/issues/5804
In the paper : https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ Timothy Dozat suggested a new optimization algorithm with an essence of combination of NAG and Adam algorithms.
It is known that the idea of momentum can be improved with the Nesterov acceleration in optimization algorithms, and Dozat is investigating to apply this idea to momentum component of Adam algorithm. Author provided experiment evidence in their work to show excellence of the idea.
In this PR we are implementing the proposed algorithm NAdam in the mentioned paper. Author has a preliminary work http://cs229.stanford.edu/proj2015/054_report.pdf where he shows the decay base constant should be taken as 0.96 which we also followed the same phenomenon here in this implementation similar to Keras. Moreover, implementation / coding practice have been followed similar to Keras in some other places as well:
f9d3868495/tensorflow/python/keras/optimizer_v2/nadam.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59009
Reviewed By: gchanan, vincentqb
Differential Revision: D29220375
Pulled By: iramazanli
fbshipit-source-id: 4b4bb4b15f7e16f7527f368bbf4207ed345751aa
Summary:
Fixes : https://github.com/pytorch/pytorch/issues/24892
In the paper : https://arxiv.org/pdf/1908.03265.pdf Liyuan Liu et al. suggested a new optimization algorithm with an essence of similar to Adam Algorithm.
It has been discussed in the paper that, without warmup heuristic, in the early stage of adaptive optimization / learning algorithms sometimes we can get undesirable large variance which can slow overall convergence process.
Authors proposed the idea of rectification of variance of adaptive learning rate when it is expected to be high.
Differing from the paper, we selected variance tractability cut-off as 5 instead of 4. This adjustment is common practice, and could be found in the code-repository and also tensorflow swift optim library as well :
2f03dd1970/radam/radam.py (L156)f51ee4618d/Sources/TensorFlow/Optimizers/MomentumBased.swift (L638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58968
Reviewed By: gchanan
Differential Revision: D29241736
Pulled By: iramazanli
fbshipit-source-id: 288b9b1f3125fdc6c7a7bb23fde1ea5c201c0448
Summary:
# What is this?
This is an implementation of the AdamW optimizer as implemented in [the fastai library](803894051b/fastai/callback.py) and as initially introduced in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101). It decouples the weight decay regularization step from the optimization step during training.
There have already been several abortive attempts to push this into pytorch in some form or fashion: https://github.com/pytorch/pytorch/pull/17468, https://github.com/pytorch/pytorch/pull/10866, https://github.com/pytorch/pytorch/pull/3740, https://github.com/pytorch/pytorch/pull/4429. Hopefully this one goes through.
# Why is this important?
Via a simple reparameterization, it can be shown that L2 regularization has a weight decay effect in the case of SGD optimization. Because of this, L2 regularization became synonymous with the concept of weight decay. However, it can be shown that the equivalence of L2 regularization and weight decay breaks down for more complex adaptive optimization schemes. It was shown in the paper [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) that this is the reason why models trained with SGD achieve better generalization than those trained with Adam. Weight decay is a very effective regularizer. L2 regularization, in and of itself, is much less effective. By explicitly decaying the weights, we can achieve state-of-the-art results while also taking advantage of the quick convergence properties that adaptive optimization schemes have.
# How was this tested?
There were test cases added to `test_optim.py` and I also ran a [little experiment](https://gist.github.com/mjacar/0c9809b96513daff84fe3d9938f08638) to validate that this implementation is equivalent to the fastai implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21250
Differential Revision: D16060339
Pulled By: vincentqb
fbshipit-source-id: ded7cc9cfd3fde81f655b9ffb3e3d6b3543a4709
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**
This was requested by someone at Facebook; this lint is turned
on for Facebook by default. "Sure, why not."
I had to noqa a number of imports in __init__. Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it. Left for future work.
Be careful! flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments. flake8-3 will
report an import unused; flake8-2 will not. For now, I just
noqa'd all these sites.
All the changes were done by hand.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14687478
fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3