This continues the full deprecation after https://github.com/pytorch/pytorch/pull/114425. It's been 6 months! And I'm fairly certain no one is going to yell at me as this patch is not really used.
------
# BC Breaking note
As of this PR, SparseAdam will become consistent with the rest of our optimizers in that it will only accept containers of Tensors/Parameters/param groups and fully complete deprecation of this path. Hitherto, the SparseAdam constructor had allowed raw tensors as the params argument to the constructor. Now, if you write the following code, there will be an error similar to every other optim: "params argument given to the optimizer should be an iterable of Tensors or dicts"
```
import torch
param = torch.rand(16, 32)
optimizer = torch.optim.SparseAdam(param)
```
Instead you should replace the last line with
```
optimizer = torch.optim.SparseAdam([param])
```
to no longer error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127081
Approved by: https://github.com/soulitzer
Is this supposed to be bitwise identical? Wasn't sure how to interpret the comment but it seems to be giving mismatches like:
```
Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 4.6372413635253906e-05 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 3.4600801882334054e-05 at index (1,) (up to 1.3e-06 allowed)
To execute this test, run the following from the base repo dir:
python test/test_optim.py -k test_complex_2d_LBFGS_cpu_complex64
```
on Neoverse-N2 SBSA ARM CPUs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126358
Approved by: https://github.com/lezcano, https://github.com/janeyx99
This PR is meant to address issue #123451, more specifically, the ```test_graph_optims``` and ```test_graph_scaling_fused_optimizers``` functions in ```test_cuda.py``` have been updated so that they now use the new OptimizerInfo infrastructure.
Lintrunner passed:
```
$ lintrunner test/test_cuda.py
ok No lint issues.
```
Tests passed:
```
>python test_cuda.py -k test_graph_optims
Ran 19 tests in 7.463s
OK (skipped=9)
>python test_cuda.py -k test_graph_scaling_fused_optimizers
Ran 6 tests in 2.800s
OK (skipped=3)
```
Both the functions have been moved to the newly created TestCase class ```TestCudaOptims```. The test is mostly the same except the ```@optims``` decorator is used at the top of the function to implicitly call the function using each of the optimizers mentioned in the decorator instead of explicitly using a for loop to iterate through each of the optimizers.
I was unable to use the ```_get_optim_inputs_including_global_cliquey_kwargs``` to get all kwargs for each of the optimizers since some of the kwargs that are used in the original ```test_graph_optims``` function are not being returned by the new OptimizerInfo infrastructure, more specifically, for the ```torch.optim.rmsprop.RMSprop``` optimizer, the following kwargs are not returned whenever ```_get_optim_inputs_including_global_cliquey_kwargs``` is called:
```
{'foreach': False, 'maximize': True, 'weight_decay': 0}
{ 'foreach': True, 'maximize': True, 'weight_decay': 0}
```
I ran into the same issue for ```test_graph_scaling_fused_optimizers```, for the ```torch.optim.adamw.AdamW``` optimizer, whenever ```optim_info.optim_inputs_func(device=device)``` was called, the following kwarg was not returned:
```
{'amsgrad': True}
```
Due to this issue, I resorted to using a dictionary to store the kwargs for each of the optimizers, I am aware that this is less than ideal. I was wondering whether I should use the OptimizerInfo infrastructure to get all the kwargs regardless of the fact that it lacks some kwargs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125127
Approved by: https://github.com/janeyx99
> previous: Originally, the variables `new_eta` and `new_mu` would be constructed `len(grouped_mus)` times, but each of their values is the same and won't be changed. Therefore, it can be simplified using Python list multiplication, which only constructs one tensor.
- [X] Ill assumption that every param will have the same step.
- [x] DIfferent implementation between `foreach=Ture` and `foreach=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125440
Approved by: https://github.com/janeyx99
Support fused_sgd_kernel support for CPU.
## Bench result:
32 core/sockets ICX
Test Scripts:
https://gist.github.com/zhuhaozhe/688763e17e93e4c5e12f25f676ec90d9https://gist.github.com/zhuhaozhe/ad9938694bc7fae8b66d376f4dffc6c9
```
Tensor Size: 262144, Num Tensor 4, Num Threads: 1
_single_tensor_sgd time: 0.2301 seconds
_fused_sgd time: 0.0925 seconds
Tensor Size: 4194304, Num Tensor 32, Num Threads: 32
_single_tensor_sgd time: 2.6195 seconds
_fused_sgd time: 1.7543 seconds
```
## Test Plan:
```
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_optim.py -k test_can_load_older_state_dict
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
python test_torch.py -k test_grad_scaling_autocast_fused
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
```
Looks like we already have some PRs under this issue https://github.com/pytorch/pytorch/issues/123451 to unified the UTs, I did not modified UT in this PR.
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123629
Approved by: https://github.com/jgong5, https://github.com/janeyx99
On par with `CUDA` implementation.
For `autocast` logic, same with `CUDA` + `Fused Adam`:
- check inf in `gradscalar.step`
- In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.
**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused
# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict
# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```
**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)
```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```
**Note: Fused kernel accuracy**
The accuracy failure in CI shows a little higher than default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debug it step by step and unfortunately we may not able to make the `fused kernel` exactly same with `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
// std::cout << "exp_avg_sq " << exp_avg_sq_ptr[d] << std::endl;
exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I can get exactly same results in UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment out it, there will be a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than default one.
Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
Removed a bunch of skips, I also updated test_forloop_goes_right_direction to *not* use the closure when dynamo is tracing. The reason for this is that testing the disabled optimizer doesn't actually test anything.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123322
Approved by: https://github.com/janeyx99
ghstack dependencies: #123498
This is the last of the old TestOptim! With this change, everything will be migrated to use OptimizerInfo. Our sparse support is...well, sparse, and the tests try to best encapsulate which configs actually work. Note that support_sparse is actually just supports sparse grads...we don't test sparse params.
1. This PR fixes a bug in Adagrad multi_tensor with maximize by passing the correct value of maximize (vs False everytime) when sparse values are present.
2. This PR does improve coverage. There used to only be 2 configs each, and now we have the following configs for:
Adagrad:
```
python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Adagrad
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'lr': 0.1} <--- this and above are CPU
.{'foreach': False, 'lr': 0.1}
{'foreach': True, 'lr': 0.1}
{'maximize': True, 'foreach': False, 'lr': 0.1}
{'maximize': True, 'foreach': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': False, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': True, 'lr': 0.1}
.
----------------------------------------------------------------------
Ran 2 tests in 227.744s
OK
```
SGD
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_SGD
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'dampening': 0.5, 'lr': 0.0048}
.{'foreach': False, 'lr': 0.0048}
{'foreach': True, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': False, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': True, 'lr': 0.0048}
.
----------------------------------------------------------------------
Ran 2 tests in 112.801s
OK
```
SparseAdam
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Sparse
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.04}
.{'maximize': True, 'lr': 0.04}
.
----------------------------------------------------------------------
Ran 2 tests in 35.113s
OK
```
Fixes#103322. A side quest in this migration was to re-enable and track dynamo issues as they trigger on the optim tests, which will be complete from this PR. New tests may add more things to track in dynamo, but there is now an established system for doing so, and dynamo is either enabled or a bug is tracked for every migrated test in TestOptimRenewed.
Next steps:
Remove the hyperparameter constraints in common_optimizer.py defined by metadata_for_sparse (other than LR, which seems handpicked for the tests to actually pass). Doing this requires adding more sparse functionality.
Add more tests!
Maybe add more optimizers!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123146
Approved by: https://github.com/albanD
ghstack dependencies: #123134, #123139
@tfsingh I got to it first--wanted to land this stack and close the gap ASAP.
This PR also fixes a discrepancy between `_init_group` and `__set_state__` because we have the constants live on params' device always.
There are some next steps though:
- ASGD can be made faster by making etas, mus, steps be on CPU when NOT capturable. (I had mistakenly thought foreachifying was faster and so we landed https://github.com/pytorch/pytorch/pull/107857, but it is slower). No one has complained yet though. ¯\_(ツ)_/¯
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121264
Approved by: https://github.com/albanD
ghstack dependencies: #121260