Commit Graph

311 Commits

Author SHA1 Message Date
fa127d9b20 Fix LBFGS wolfe max iteration (#161488)
Fixes #91581, based on #135026

## Test Result

```bash
pytest test/test_optim.py

.........
========================== 1473 passed, 242 skipped in 2412.49s (0:40:12) ===========================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161488
Approved by: https://github.com/albanD
2025-09-16 12:07:50 +00:00
74280d0913 [muon] Introduce Muon optimizer to PyTorch (#160213)
A single-device version of Muon. The algorithm follows Keller Jordan's [Muon blogpost](https://kellerjordan.github.io/posts/muon/) and optionally incorporates [Moonshot's](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf) learning rate adjustment strategy.

This implementation maintains a minimalist API consistent with other optimizer conventions. The PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the [PyTorch examples](https://github.com/pytorch/examples) directory.

**Usage**

```python
    model = MyModelForCausalLM(...)
    # filter out your params manually
    muon_params = [...]
    adamw_params = [...]
    muon = Muon(
        params=muon_params,
        lr=lr,
        weight_decay=wd,
    )
    adamw = AdamW(
        params=adamw_params,
        lr=lr,
        weight_decay=wd,
    )

    # in training loop
    loss = model(input)
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()
```

~~**Additional usage**~~
~~Users were also able to pass in a self-defined `msign` function for orthogonalization and a learning-rate adjustment function, with the interface defined below:~~

```python
# Interface removed before landing, as explained below:
AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]
MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]
```

As discussed with the team and in the comments, we prefer to keep the interface simpler and cleaner, so we removed the callback interface and canonicalized the original Newton-Schulz (NS) algorithm for Muon. The only configs available to users are `ns_steps`, `coefficients`, and `eps`, configurable through kwargs.

By default, we use 5-step Newton-Schulz with the coefficients proposed by [Keller](https://kellerjordan.github.io/posts/muon/). We use the LR adjustment proposed by [Moonshot](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf), which grafts the learning rate from AdamW; a sketch of both pieces follows.
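
For reference, here is a minimal sketch of the canonical 5-step Newton-Schulz orthogonalization and the Moonshot-style LR adjustment, assuming the formulas from the cited blog post and paper; the function names and signatures are illustrative, not PyTorch's actual internals.

```python
import torch

def newton_schulz_msign(grad: torch.Tensor, ns_steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration with the coefficients from the blog post.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = grad / (grad.norm() + eps)  # normalize so the iteration converges
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(ns_steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def adjust_lr(lr: float, shape: torch.Size) -> float:
    # Moonlight-style grafting: scale lr by the square root of the larger
    # matrix dimension so AdamW-tuned hyper-parameters carry over.
    return lr * 0.2 * max(shape[0], shape[1]) ** 0.5
```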

**Testing**

~~1. Unit tests: the newly introduced Muon is covered in `test/test_optim.py`. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.~~

As discussed, in order not to complicate the codebase, we prefer not to include a reference implementation in PyTorch. We also updated the interface so we no longer need to test FQN-based filtering. Muon is covered by the existing `test_optim.py` unit tests.

2. End-to-end test: we added a training script that pre-trains a Qwen-like model on the `openwebtext-100k` dataset. We trained for one epoch, and the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.
<img width="1102" height="472" alt="Screenshot 2025-07-29 at 1 04 12 AM" src="https://github.com/user-attachments/assets/ceab0733-497d-4070-8032-02ae7995c64c" />

**Numerics**
We evaluated our implementation against existing implementations to confirm numerical consistency.

As discussed, our implementation closely follows the algorithm described in [Keller's post](https://kellerjordan.github.io/posts/muon/), while incorporating the learning rate adjustment from [Moonlight](https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf). This captures a key insight that allows users to reuse hyper-parameters tuned for `AdamW`, making Muon a drop-in replacement.

As expected, the numerics difference comes mainly from `adjust_lr`: a maximum of ~5% relative difference in the example unit-test setup below.

```python
    import copy

    import torch
    from torch.nn import Linear, MSELoss
    from torch.optim import Muon

    # KellySingleDeviceMuon is the reference single-device implementation
    # used for comparison; its import is omitted here.

    # dummy model and data
    model0 = Linear(10, 10, bias=False)
    model1 = copy.deepcopy(model0)
    inputs = torch.randn(8, 10)
    targets = torch.randn(8, 10)
    loss = MSELoss()

    lr = 1e-3
    wd = 0.1
    momentum = 0.95

    opt_ref_muon = KellySingleDeviceMuon(
        params=model0.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    opt_exp_muon = Muon(
        params=model1.parameters(),
        lr=lr,
        weight_decay=wd,
        momentum=momentum,
    )

    out_ref = model0(inputs)
    loss_ref = loss(out_ref, targets)
    opt_ref_muon.zero_grad()
    loss_ref.backward()
    opt_ref_muon.step()

    out_exp = model1(inputs)
    loss_exp = loss(out_exp, targets)
    opt_exp_muon.zero_grad()
    loss_exp.backward()
    opt_exp_muon.step()

    for p_ref, p_exp in zip(model0.parameters(), model1.parameters()):
        torch.testing.assert_close(p_ref, p_exp)
```

As explained above, including this `adjust_lr` is preferable. This is validated by e2e training runs on a Qwen-2-like 0.5B model, where the curves show that training with `adjust_lr` converges more effectively than without it.
<img width="1179" height="464" alt="Screenshot 2025-08-18 at 10 12 33 AM" src="https://github.com/user-attachments/assets/e797d3da-c2f0-4187-b99e-5d48b7437c3c" />

**Performance**
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

- adamw_ddp finishes in 13.12 min
- pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20s slower than AdamW. Assuming no other changes, Muon is *2.5%* slower than AdamW.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
<img width="726" height="590" alt="Screenshot 2025-07-29 at 1 56 14 AM" src="https://github.com/user-attachments/assets/ebcd7e1c-d129-4b20-9396-39f568edf03d" />

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
<img width="751" height="597" alt="Screenshot 2025-07-29 at 2 02 20 AM" src="https://github.com/user-attachments/assets/72f5b904-ebd5-4502-a6ff-d3e9e5a6da81" />

**Note**
We restrict the implementation to accept only 2D parameters.

An alternative approach is to allow parameters with more than two dimensions and apply orthogonalization over the last two dimensions. We opted not to take this approach, as it can be error-prone. For example, with a kernel shaped `[in_channel, height, width, out_channel]`, applying orthogonalization to the last two dimensions is not meaningful.

Since Muon is designed to perform orthogonalization on 2D matrices, preserving this assumption keeps the implementation clean and sound.
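
To make the restriction concrete, here is a hedged sketch (not part of the PR) of the manual filtering this implies; the helper name and the ndim-based policy are illustrative assumptions.

```python
import torch.nn as nn

def split_params_for_muon(model: nn.Module):
    """Group trainable params: 2D matrices for Muon, everything else for AdamW."""
    muon_params, adamw_params = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # Muon only accepts 2D parameters; biases, norms, etc. go to AdamW.
        (muon_params if param.ndim == 2 else adamw_params).append(param)
    return muon_params, adamw_params
```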

**Next Steps**

1. Add `MuP`
2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x - 1.48x speedup on small to large (n = 256 -> 16384) matrices.
3. Open-source unsharded Muon co-designed with FSDP2.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/160213
Approved by: https://github.com/janeyx99
2025-08-24 08:03:04 +00:00
4270517cbf Fix test/test_optim.py error message. (#153076)
Fixes an error message in test/test_optim.py

Current behavior: If running the test with Adagrad, the error message reads: "SGD does not currently support capturable".

Fix: The error message now correctly says: "Adagrad does not currently support capturable".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153076
Approved by: https://github.com/janeyx99
2025-05-07 22:46:05 +00:00
d5b1d99f78 Enable more nightly tests on s390x (#148452)
Also enable some tests that were probably disabled accidentally.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148452
Approved by: https://github.com/seemethere, https://github.com/malfet
2025-03-18 16:09:39 +00:00
78715a181f Convert Tensor lr to 0-dim as needed for the optimizer to normally work (#145674)
Fixes #145461

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145674
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-03-17 23:07:05 +00:00
c73a92fbf5 [BE][CI] bump ruff to 0.9.2: multiline assert statements (#144546)
Reference: https://docs.astral.sh/ruff/formatter/black/#assert-statements

> Unlike Black, Ruff prefers breaking the message over breaking the assertion, similar to how both Ruff and Black prefer breaking the assignment value over breaking the assignment target:
>
> ```python
> # Input
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
>
> # Black
> assert (
>     len(policy_types) >= priority + num_duplicates
> ), f"This tests needs at least {priority+num_duplicates} many types."
>
> # Ruff
> assert len(policy_types) >= priority + num_duplicates, (
>     f"This tests needs at least {priority + num_duplicates} many types."
> )
> ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144546
Approved by: https://github.com/malfet
2025-02-27 20:46:16 +00:00
3908be676c Fix loading older state_dict into AdamW after refactor (#144972)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144972
Approved by: https://github.com/albanD
2025-01-16 19:50:31 +00:00
c8713e659a fix memleak, detach instead of clone to not drag around graph (#144154)
Thanks @clee2000 for bringing the memleak to my attention: https://github.com/pytorch/pytorch/actions/runs/12549765082/job/34996244798.

This memleak in the test was caused by the differentiable flavors. Because we had `param.clone()` and `param` persisted outside the for loop, the autograd graph would keep growing with each `optimizer.step` instead of being deleted after the optim input was used up.

To clarify, I had expected (and still do expect) the test to fully clean everything up once the test is over, but I didn't get the chance to look into why that's not the case. This change preliminarily unblocks this particular test from failing the memleak CI.

Use detach instead of clone, which is...cheaper anyway :D since a detach, as I've learned from @soulitzer, is a view with `requires_grad=False`.
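
A tiny illustration of the difference described above (example mine, not from the test): a clone of a parameter that requires grad stays attached to the autograd graph via a `CloneBackward` node, while a detach is just a `requires_grad=False` view.

```python
import torch

param = torch.randn(3, requires_grad=True)

attached = param.clone()   # records CloneBackward: still part of the graph
free = param.detach()      # plain view with requires_grad=False, no graph

print(attached.grad_fn)    # <CloneBackward0 object ...>
print(free.grad_fn)        # None
```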

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144154
Approved by: https://github.com/clee2000, https://github.com/Skylion007, https://github.com/huydhn, https://github.com/albanD
2025-01-06 17:09:00 +00:00
cyy
df458be4e5 [4/N] Apply py39 ruff and pyupgrade fixes (#143257)
`torch/fx/passes/annotate_getitem_nodes.py` was changed to support the new type hinting annotations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143257
Approved by: https://github.com/justinchuby, https://github.com/albanD
2025-01-04 10:47:51 +00:00
6ccb8ed186 Refactor AdamW into Adam (heavily inspired by tfsingh) (#143710)
Fixes #104899

Refactors AdamW into Adam by making AdamW a subclass of Adam. Additionally adds a test asserting that the added parameter `decoupled_weight_decay` is True in AdamW, and updates `test_defaults_changed_to_foreach` to account for AdamW's new module location.

Heavily heavily inspired by #118857 by @tfsingh
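
A minimal sketch of the refactor's shape, assuming the `decoupled_weight_decay` flag described above; illustrative, not the actual PyTorch source.

```python
from torch.optim import Adam

class AdamW(Adam):
    """AdamW as an Adam subclass that fixes decoupled weight decay on."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, **kwargs):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay,
                         decoupled_weight_decay=True, **kwargs)
```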

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143710
Approved by: https://github.com/janeyx99
2024-12-23 23:27:28 +00:00
4e29e4aa63 [BE] Add a test to ensure grads are never inplaced into accidentally (#143612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143612
Approved by: https://github.com/soulitzer
2024-12-20 06:15:08 +00:00
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
e1196dfe51 Deprecate torch._utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-12-08 22:55:36 +00:00
d8b4406e12 [MPS] Expand fused forloop to bfloat16 (#141104)
For macOS 14+

Running the following script (adapted from the one mentioned in https://github.com/pytorch/pytorch/pull/127242):
```python
import torch
from torch.optim import adam, adamw
import torch.utils.benchmark as benchmark
import itertools

def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
    fn(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        foreach=False,
        capturable=False,
        fused=fused,
        amsgrad=amsgrad,
        beta1=0.9,
        beta2=0.99,
        lr=1e-3,
        weight_decay=.0,
        eps=1e-5,
        maximize=False,
        grad_scale=None,
        found_inf=None,
    )
    torch.mps.synchronize()

device, dtype = "mps", torch.bfloat16

results = []

for num_tensors, numel, adamWflag, amsgrad in itertools.product([10, 50, 100], [1024, 65536, 1048576], [True, False], [True, False]):
    print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
    params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=dtype, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
    max_exp_avg_sqs = [torch.arange(numel, dtype=dtype, device=device) for _ in range(num_tensors)] if amsgrad else []
    state_steps = [torch.tensor([5], dtype=dtype, device=device) for _ in range(num_tensors)]
    fn = adamw.adamw if adamWflag else adam.adam

    for fused in [True, False]:

        t = benchmark.Timer(
                stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                label=f'Fused Adam on {device} using {dtype}',
                sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                globals=locals(),
                description= f"Fused: {fused}",
            ).blocked_autorange(min_run_time=5)
        results.append(t)

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
```

Produces the following results on an M4 Pro running macOS 15:
```
[-------------------------------- Fused Adam on mps using torch.bfloat16 -------------------------------]
                                                                          |  Fused: True  |  Fused: False
1 threads: ----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 10        |       283     |      2810
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 10       |       277     |      2430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 10       |       285     |      2400
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 10      |       278     |      2250
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 10       |       504     |      2700
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 10      |       478     |      2600
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 10      |       506     |      2500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 10     |       482     |      2300
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 10     |      2089     |      4190
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 10    |      1940     |      3800
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 10    |      2100     |      3770
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 10   |      1950     |      3600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 50        |       842     |     14000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 50       |       835     |     11800
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 50       |       845     |     11700
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 50      |       855     |     11000
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 50       |      1410     |     14000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 50      |      1350     |     12000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 50      |      1400     |     12000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 50     |      1340     |     11000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 50     |      9767     |     20400
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 50    |      8991     |     18600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 50    |      9803     |     18300
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 50   |      9070     |     17600
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100       |      1600     |     27000
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100      |      1600     |     24100
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100      |      1600     |     23500
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100     |      1600     |     21800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100      |      2740     |     26000
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100     |      2580     |     24000
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100     |      2730     |     25000
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100    |      2600     |     23000
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100    |     19350     |     39000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100   |     17780     |     37300
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100   |     19400     |     37000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100  |     17900     |     35500
Times are in microseconds (us).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
2024-11-22 01:07:15 +00:00
989888236e Revert "[MPS] Expand fused forloop to bfloat16 (#141104)"
This reverts commit 9a729390420570cd2528ce2e9947e3eab209660b.

Reverted https://github.com/pytorch/pytorch/pull/141104 on behalf of https://github.com/malfet due to Want to add test script to the commit message ([comment](https://github.com/pytorch/pytorch/pull/141104#issuecomment-2492659931))
2024-11-22 01:03:43 +00:00
9a72939042 [MPS] Expand fused forloop to bfloat16 (#141104)
For macOS 14+

Running the following script produces the results shown in the re-landed commit above. (The test script was omitted from this commit message, which prompted the revert; the benchmark table is identical to the one in the re-landed commit.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141104
Approved by: https://github.com/qqaatw, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #141089, #141090, #141092, #141103
2024-11-21 23:30:37 +00:00
a82bab6419 Run only listed tests on s390x (#140265)
Skip tests that are failing

This was previously part of https://github.com/pytorch/pytorch/pull/125401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140265
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-11-20 22:53:09 +00:00
cb71bcc542 Replace clone.detach with detach.clone (#140264)
Fixes #64532

As stated in the issue, replace `clone.detach` with `detach.clone`.
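
A quick illustration (example mine, not from the PR) of why the order matters: detaching first means the subsequent clone is never recorded by autograd.

```python
import torch

x = torch.randn(4, requires_grad=True)

y = x.clone().detach()  # the clone op is tracked by autograd, then detached
z = x.detach().clone()  # detach first, so the clone is never tracked
```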

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140264
Approved by: https://github.com/soulitzer
2024-11-13 07:01:02 +00:00
1d28b8b6d5 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit e84d1121ad66a453c8c24fcc098625e2e9764fca.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
2024-11-05 23:10:38 +00:00
e84d1121ad Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-05 10:44:56 +00:00
197601eeea Add Support for Tracking Parameter Names (named_parameters) in Optimizer State Dict (#134107)
A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**

(also mentioned here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))

## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their ids.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict = {
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']  # optional
        }
    ]
}
```
Loading a `state_dict` is unchanged (backward-compatible); the `param_names` key will be ignored.

## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.

## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.

#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.

## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.

## Solution Example:

A suggested solution to the problem mentioned in #1489, for the same parameters but in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict:
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    loaded_state_name_to_id_mapping = {}  # mapping -- param_name: id
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_state_name_to_id_mapping[name] = loaded_state_group['params'][i]

    # reorder the loaded ids to follow the current optimizer's param order,
    # so each param picks up the state that was saved under its name.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = loaded_state_name_to_id_mapping[name]

    return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers must be initialized with `named_parameters()` so that the `param_names` key is present in the dict.
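
A hypothetical end-to-end usage of the hook (`model` and `saved_state_dict` are assumed to exist):

```python
import torch.optim as optim

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer.load_state_dict(saved_state_dict)  # ids remapped by param name
```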

### Note
This is my first contribution to PyTorch, and I would welcome feedback or suggestions for improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-10-14 19:24:44 +00:00
702c810780 move param's device check to _init_group for fused (#131153)
There could be cases where the params are on the meta device when the optimizer's dunder init is called and are only materialized at the first computation. This change allows that situation.
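
A hedged sketch of the scenario this enables (illustrative, and it assumes `to_empty` keeps the optimizer's parameter references valid by swapping tensor data in place):

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(8, 8)  # params start on the meta device

# Previously a fused optimizer rejected meta params in __init__;
# with this change the device check happens lazily in _init_group.
opt = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

model.to_empty(device="cpu")  # materialize params before the first step
```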
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131153
Approved by: https://github.com/mlazos, https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2024-08-17 04:49:47 +00:00
c23dceb8f1 Add Adafactor foreach impl (#132336)
This PR adds the foreach impl for Adafactor, knowing that there are many ways to improve its runtime perf today (by adding more foreach support). After this PR:
- we have a foreach flag for Adafactor
- It is NOT the default. Why not? It is only slightly faster + uses O(n) more memory where n is the number of params in your max param group. People tend to use Adafactor for memory efficiency.

Next steps:
- make torch.compile possible on it
- make it faster (by adding more foreach apis)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132336
Approved by: https://github.com/albanD
ghstack dependencies: #133360
2024-08-15 17:00:33 +00:00
e7d8d73582 [minor] Correct in-code documentation for complex numbers and LBFGS (#133020)
To @lezcano's credit, this should be associative, as floating point add is actually commutative per IEEE754.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133020
Approved by: https://github.com/soulitzer, https://github.com/lezcano
2024-08-12 18:04:11 +00:00
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7cec82a43f2de52b83eff152d703be7a3.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
9c4cf866c2 Adafactor forloop basic impl (#129905)
#109581

At this point, the vanilla implementation (the default) is good.
Docs: https://docs-preview.pytorch.org/pytorch/pytorch/129905/generated/torch.optim.Adafactor.html#torch.optim.Adafactor

Specifically, the impl in this PR, which attempts to replicate the paper,
```
optim = torch.optim.Adafactor([weight])
```
is close enough to https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.AdaFactor
```
optim_c = AdaFactor([weight], betas=(0, 0.999), scale_parameter=False)
```
which in turn is close enough to https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adafactor
```
optim = keras.optimizers.Adafactor(learning_rate=0.01)
```

The three results respectively for the same randomly generated weights:
```
# ours
tensor([[ 0.3807594, -0.3912092],
        [ 0.0762539,  0.5377805],
        [ 0.2459473,  0.4662207]])

# pytorch-optimizer
tensor([[ 0.3807592, -0.3912172],
        [ 0.0762507,  0.5377818],
        [ 0.2459457,  0.4662213]])

# keras
array([[ 0.38076326, -0.39121315],
        [ 0.0762547 ,  0.5377859 ],
        [ 0.24594972,  0.46622536]], dtype=float32)
```

This gives me confidence to move forward in speeding up the implementation now that a baseline has been established. If you're curious about differences:
* keras assigns step_size (rho_t in their code) to `min(lr, 1 / sqrt(step))`, whereas the OG impl uses a hardcoded 0.01 instead of lr. We do the same thing as keras, but our lr default is 0.01 (see the sketch after this list).
* We differ from the pytorch-optimizers default in that our default will not track momentum (thus `beta1=0`) and we do not apply parameter scaling.
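
To make the first bullet concrete, a one-line rendering of that step-size rule (illustrative helper, not library code):

```python
import math

def rho_t(lr: float, step: int) -> float:
    # keras-style step size; the OG paper hardcodes 0.01 where lr appears here.
    return min(lr, 1.0 / math.sqrt(step))
```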

<details>

Keras collab: https://colab.research.google.com/drive/1i3xF8ChL7TWKJGV_5v_5nMhXKnYmQQ06?usp=sharing

My script repro:

```
import torch
from pytorch_optimizer import AdaFactor
torch.set_printoptions(precision=7)

weight = torch.tensor([[ 0.37697506, -0.39500135],
        [ 0.07246649,  0.53399765],
        [ 0.24216151,  0.46243715]], dtype=torch.float32)
# bias = torch.tensor([0, 0], dtype=torch.float32)

weight.grad = torch.tensor([[-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838],
        [-0.5940447, -0.7743838]], dtype=torch.float32)
# bias.grad = torch.tensor([-2.5027974,  1.5422692], dtype=torch.float32)

weight_c = weight.clone()
weight_c.grad = weight.grad.clone()

optim = torch.optim.Adafactor([weight])
optim.step()
print(weight)

optim_c = AdaFactor([weight_c], betas=(0, 0.999), scale_parameter=False)
optim_c.step()
print(weight_c)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129905
Approved by: https://github.com/albanD
2024-07-25 13:17:19 +00:00
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
e57101d927 Add testing regarding SparseAdam state_dicts (#130645)
Summary:
- Updated SparseAdam to run test_state_dict_deterministic unit test.
- Made gradients sparse while keeping weights dense in the above test.

Test Plan:
- Ran test_optim.py locally.

Fixes #116507

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130645
Approved by: https://github.com/janeyx99
2024-07-16 11:29:22 +00:00
d62d351107 [Optim][BE] Change str(device) to _get_device_type(device) (#129984)
Prevent using vague expressions like `"cuda" in str(device)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129984
Approved by: https://github.com/janeyx99
ghstack dependencies: #129451, #129552
2024-07-04 06:44:48 +00:00
9a7e2519d3 [MPS] Fused Adam & AdamW (#127242)
Summary:

This PR adds fused Adam and AdamW implementations.

Benchmark on Macbook Pro with M1 Max chip and 64GB unified memory:
**Fast math enabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        89
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        90
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        83
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       12      |        94
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       11      |        88
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       12      |        90
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |       100
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       27      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       23      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       27      |       100
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       23      |        98
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       82      |       480
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       72      |       450
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       82      |       450
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       73      |       420
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       91      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       83      |       400
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |       94      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       78      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      170      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      140      |       600
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      170      |       600
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      140      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      250      |       890
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      220      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      250      |       830
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      220      |       770
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      270      |       870
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      230      |       840
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      270      |       810
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      240      |       800
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      400      |      1000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      360      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      430      |      2000
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      360      |      1300

Times are in milliseconds (ms).
```

**Fast math disabled:**
```
[---------------------------------------------- Fused Adam ----------------------------------------------]
                                                                           |  Fused: True  |  Fused: False
1 threads: -----------------------------------------------------------------------------------------------
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 100        |       10      |       100
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 100       |        9      |        84
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 100      |        9      |        79
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 100       |       11      |        93
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 100      |       10      |        90
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 100      |       11      |        91
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 100     |       11      |        81
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 100     |       34      |       100
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 100    |       31      |       100
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 100    |       34      |        95
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 100   |       31      |       100
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 500        |       94      |       500
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 500       |       82      |       430
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 500       |       92      |       430
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 500      |       81      |       390
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 500       |       98      |       500
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 500      |       88      |       430
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 500      |      100      |       500
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 500     |       88      |       400
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 500     |      210      |       500
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 500    |      190      |       610
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 500    |      210      |       510
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 500   |      190      |       500
      amsgrad: True, adamWflag: True, numel: 1024, num_tensors: 1000       |      300      |       900
      amsgrad: False, adamWflag: True, numel: 1024, num_tensors: 1000      |      260      |       850
      amsgrad: True, adamWflag: False, numel: 1024, num_tensors: 1000      |      295      |       900
      amsgrad: False, adamWflag: False, numel: 1024, num_tensors: 1000     |      260      |       800
      amsgrad: True, adamWflag: True, numel: 65536, num_tensors: 1000      |      320      |       910
      amsgrad: False, adamWflag: True, numel: 65536, num_tensors: 1000     |      280      |       900
      amsgrad: True, adamWflag: False, numel: 65536, num_tensors: 1000     |      320      |       900
      amsgrad: False, adamWflag: False, numel: 65536, num_tensors: 1000    |      300      |       900
      amsgrad: True, adamWflag: True, numel: 1048576, num_tensors: 1000    |      500      |      2000
      amsgrad: False, adamWflag: True, numel: 1048576, num_tensors: 1000   |      480      |      2000
      amsgrad: True, adamWflag: False, numel: 1048576, num_tensors: 1000   |      540      |      1500
      amsgrad: False, adamWflag: False, numel: 1048576, num_tensors: 1000  |      480      |      1200

Times are in milliseconds (ms).
```

```python
def profile_fused_adam():
    from torch.optim import adam, adamw
    import torch.utils.benchmark as benchmark

    import itertools

    def profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused):
        fn(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=fused,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-3,
            weight_decay=.0,
            eps=1e-5,
            maximize=False,
            grad_scale=None,
            found_inf=None,
        )
        torch.mps.synchronize()

    device = "mps"

    results = []

    for num_tensors, numel, adamWflag, amsgrad in itertools.product([100, 500, 1000], [1024, 65536, 1048576], [True, False], [True, False]):
        print(f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}")
        params, grads, exp_avgs, exp_avg_sqs = [[torch.arange(numel, dtype=torch.float32, device=device) + (numel * i) for i in range(num_tensors)] for _ in range(4)]
        max_exp_avg_sqs = [torch.arange(numel, dtype=torch.float32, device=device) for _ in range(num_tensors)] if amsgrad else []
        state_steps = [torch.tensor([5], dtype=torch.float32, device=device) for _ in range(num_tensors)]
        if adamWflag:
            fn = adamw.adamw
        else:
            fn = adam.adam

        for fused in [True, False]:

            t = benchmark.Timer(
                    stmt='profile(fn, params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, fused)',
                    label='Fused Adam',
                    sub_label=f"amsgrad: {amsgrad}, adamWflag: {adamWflag}, numel: {numel}, num_tensors: {num_tensors}",
                    globals=locals(),
                    description= f"Fused: {fused}",
                ).blocked_autorange(min_run_time=5)
            results.append(t)

    compare = benchmark.Compare(results)
    compare.trim_significant_figures()
    compare.colorize(rowwise=True)
    compare.print()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127242
Approved by: https://github.com/kulinseth, https://github.com/janeyx99
2024-06-18 19:59:50 +00:00
90bb510ece Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 348b181a97abc2e636a6c18e5880a78e5d1dab94.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/clee2000 due to sorry I think https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 is still relevant, I will reach out to them to see what needs to be done in internal to get this remerged ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2159248859))
2024-06-10 20:44:42 +00:00
348b181a97 Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007
2024-06-08 15:25:03 +00:00
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0a8325cbad4734a563aa459ca611991.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
da39461d61 [optim] Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py (#126418)
This PR addresses the comments in PR #124904:

- Move test_grad_scaling_autocast_fused_optimizers to test_cuda.py
- Combine _grad_scaling_autocast_fused_optimizers into test_grad_scaling_autocast_fused_optimizers
- Move to OptimizerInfo framework.
- For the failing tests `test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32` and `test_grad_scaling_autocast_fused_optimizers_Adam_cuda_float32`:
    - Added a `toleranceOverride` in this PR
    - Created issue #127000

```
> (c2env) [sandish@devgpu166.ash6 ~/pytorch (refactoroptimizers)]$ python test/test_cuda.py -k test_grad_scaling_autocast_fused_optimizers -v
/home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
  warnings.warn(
/home/sandish/pytorch/torch/backends/cudnn/__init__.py:106: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.
  warnings.warn(
test_grad_scaling_autocast_fused_optimizers_Adagrad_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True}
{'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'lr': 0.1, 'fused': True}
{'lr': 0.1, 'fused': True}
{'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True}
{'initial_accumulator_value': 0.1, 'weight_decay': 0.1, 'fused': True}
{'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True}
{'lr': 0.1, 'lr_decay': 0.5, 'weight_decay': 0.1, 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_AdamW_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adam_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_SGD_cpu_float32 (__main__.TestCudaOptimsCPU) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
{'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adagrad_cuda_float32 (__main__.TestCudaOptimsCUDA) ... skipped 'cuda is not supported for fused on Adagrad'
test_grad_scaling_autocast_fused_optimizers_AdamW_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'capturable': True, 'fused': True}
{'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_Adam_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'fused': True}
{'capturable': True, 'fused': True}
{'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
{'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'fused': True}
ok
test_grad_scaling_autocast_fused_optimizers_SGD_cuda_float32 (__main__.TestCudaOptimsCUDA) ... {'fused': True}
{'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': 0.01, 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
{'lr': tensor(0.0010), 'fused': True}
{'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'dampening': 0.5, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
{'weight_decay': 0.1, 'maximize': True, 'fused': True}
ok

----------------------------------------------------------------------
Ran 8 tests in 16.117s

OK (skipped=1)

> lintrunner test/test_cuda.py
----------------------------------------------------------------------
ok No lint issues.

> lintrunner torch/testing/_internal/common_optimizers.py
----------------------------------------------------------------------
ok No lint issues.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126418
Approved by: https://github.com/janeyx99
2024-05-30 01:47:41 +00:00
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
a6b994ed54 Fix lint after #126845 (#127286)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127286
Approved by: https://github.com/NicolasHug, https://github.com/DanilBaibak
2024-05-28 12:38:27 +00:00
8979412442 Enable ufmt format on test files (#126845)
Fixes some files in  #123062

Run lintrunner on files:

test/test_nnapi.py,
test/test_numba_integration.py,
test/test_numpy_interop.py,
test/test_openmp.py,
test/test_optim.py

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126845
Approved by: https://github.com/ezyang
2024-05-28 01:42:07 +00:00
7e166e8057 [optim] Fix: wrong ASGD implementation (#126375)
This PR is based on #125440, additionally merging the latest main branch and fixing the lint failures from #126361.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126375
Approved by: https://github.com/janeyx99
2024-05-17 15:46:39 +00:00
e3c5d1b7d7 Revert "[optim] Fix: wrong ASGD implementation (#125440)"
This reverts commit 2c5ad9a3d7ea79ca897aec153a401f4b9175a717.

Reverted https://github.com/pytorch/pytorch/pull/125440 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it looks like there is a linter failure coming from this change ([comment](https://github.com/pytorch/pytorch/pull/125440#issuecomment-2113833108))
2024-05-16 02:12:29 +00:00
f9d107af66 [optim] add fused_adagrad support for CPU device (#124905)
Support the `fused_adagrad` kernel for CPU.

## Bench result:
32 core/sockets ICX
Test Scripts:
https://gist.github.com/zhuhaozhe/79e842e0a6e25d6d7fa1e4598807272c
https://gist.github.com/zhuhaozhe/b4c6998a509dcea1796dd05b3005c969
```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_adagrad time: 0.2500 seconds
_fused_adagrad time: 0.0933 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_adagrad time: 2.8819 seconds
_fused_adagrad time: 1.7591 seconds
```
## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124905
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-05-16 01:11:51 +00:00
2c5ad9a3d7 [optim] Fix: wrong ASGD implementation (#125440)
> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't be changed. Therefore, it can be simplified using Python list multiplication, which only constructs one tensor.

- [x] Fixed the ill-founded assumption that every param will have the same step.
- [x] Fixed the differing implementations between `foreach=True` and `foreach=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
2024-05-15 22:52:15 +00:00
bd3cbdba2f Revert "[optim] add fused_adagrad support for CPU device (#124905)"
This reverts commit 1c3fe8403365db3cc9b75524ae742e3027b745e2.

Reverted https://github.com/pytorch/pytorch/pull/124905 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing distributed multigpu test in trunk 1c3fe84033 ([comment](https://github.com/pytorch/pytorch/pull/124905#issuecomment-2108777063))
2024-05-13 20:53:22 +00:00
1c3fe84033 [optim] add fused_adagrad support for CPU device (#124905)
Support the `fused_adagrad` kernel for CPU.

## Bench result:
32 core/sockets ICX
Test Scripts:
https://gist.github.com/zhuhaozhe/79e842e0a6e25d6d7fa1e4598807272c
https://gist.github.com/zhuhaozhe/b4c6998a509dcea1796dd05b3005c969
```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_adagrad time: 0.2500 seconds
_fused_adagrad time: 0.0933 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_adagrad time: 2.8819 seconds
_fused_adagrad time: 1.7591 seconds
```
## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124905
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-05-13 01:16:20 +00:00
b24ad7eab5 Enable dynamo traced test_param_group_with_lrscheduler_goes_right_direction (#124544)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124544
Approved by: https://github.com/janeyx99
ghstack dependencies: #125825, #125826
2024-05-11 06:29:59 +00:00
69eeef0727 Update LRScheduler to handle tensor LR (#123753)
Enables LRScheduler to handle tensor LRs.

Note on test changes:
For the test modifications, I just removed `itertools.product` and created two loops. This allows us to create a new set of `optim_inputs` on each iteration, preventing mutations of the tensor LR from carrying over across iterations. Nothing else in those tests was modified.
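
A small hypothetical example of what this enables (values are made up):

```python
import torch

params = [torch.randn(2, 2, requires_grad=True)]
opt = torch.optim.SGD(params, lr=torch.tensor(0.1))  # 0-dim tensor LR
# LRScheduler can now drive the tensor LR just like a float one.
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
```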

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123753
Approved by: https://github.com/janeyx99
ghstack dependencies: #123751, #123752
2024-05-09 00:52:43 +00:00
f0c6d6100b Enable dynamo-traced optimizer peak memory tests (#124543)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124543
Approved by: https://github.com/yf225, https://github.com/janeyx99
2024-05-07 08:21:50 +00:00
489b4586e9 [optim]fix ut and sgd kernel (#124904)
- The original `test_grad_scaling_autocast_fused_optimizers` does not work since there is no "fused" in `optim_inputs`.
- We should use different `grad_scaler`s; they should not share one `scale`. No issue is exposed here because the default `_growth_interval` is 2000, so the scale will not grow, and no inf is found, so it will not be reduced. The one in `test_cuda.py` should also have this issue.
- I set a manual seed for reproduction purposes in case of any numerical failure.
- I use a tensor tracker here because we failed this UT in the dynamo case: the generated C++ code is not exactly the same for the fused and non-fused kernels.
- I make it check both `cuda` and `cpu`.
- I found an SGD numerical issue with `clang` and fixed it by using `fmadd` instead of `add/mul` in the fused SGD vectorized kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124904
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-05-03 09:13:24 +00:00
68a027f144 Fixes for 123400 (#123406)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123406
Approved by: https://github.com/janeyx99
ghstack dependencies: #123324, #123404, #123405, #124309
2024-04-19 17:20:57 +00:00