Commit Graph

115 Commits

Author SHA1 Message Date
b412b75b42 [optim] add fused_adam/adamw_kernel support for CPU device (#123074)
On par with `CUDA` implementation.

For `autocast` logic, same with `CUDA` + `Fused Adam`:
 - check inf in `gradscalar.step`
 - In fused kernel, if there is `inf`, do nothing. If not, unscale the grad ( also write back) and update the param.

**TestPlan**:
```
# extend CUDA only test for CPU fused adagrad
python test_optim.py -k test_fused_matches_forloop
python test_optim.py -k test_fused_large_tensor
python test_torch.py -k test_grad_scaling_autocast_fused

# extend fused test
python test_torch.py -k test_params_invalidated_with_grads_invalidated_between_unscale_and_step
python test_optim.py -k test_can_load_older_state_dict

# newly added test (follow 6b1f13ea2f/test/test_cuda.py (L1108))
python test_optim.py -k test_grad_scaling_autocast_fused_optimizers
```

**Benchmark**:
**5.1x** on 56 core SPR
**Parameter-size=1M**
**Nparams=10**
[test script](https://gist.github.com/zhuhaozhe/ef9a290ad3f8f4067b3373a3bdaa33e7)

```
numactl -C 0-55 -m 0 python bench_adam.py
non-fused 6.0174267292022705 s
fused 1.1787631511688232 s
```

**Note: Fused kernel accuracy**
The accuracy failure in CI shows a little higher than default tolerance
```
2024-04-02T06:09:16.2213887Z Mismatched elements: 21 / 64 (32.8%)
2024-04-02T06:09:16.2214339Z Greatest absolute difference: 1.5735626220703125e-05 at index (6, 6) (up to 1e-05 allowed)
2024-04-02T06:09:16.2214813Z Greatest relative difference: 1.0073336852656212e-05 at index (4, 1) (up to 1.3e-06 allowed)
```
I have debug it step by step and unfortunately we may not able to make the `fused kernel` exactly same with `non fused` one due to compiler optimizations.
For example, in non-fused impl
```
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
and in fused impl
```
  exp_avg_sq_ptr[d] = scalar_t(beta2) * exp_avg_sq_ptr[d];
  //  std::cout << "exp_avg_sq " <<   exp_avg_sq_ptr[d] << std::endl;
  exp_avg_sq_ptr[d] = exp_avg_sq_ptr[d] +
      scalar_t(exp_avg_sq_grad_coefficient) * grad_val * grad_val;
```
If I keep `std::cout`, I can get exactly same results in UT
```
===============param
0.6796758770942688
0.6796758770942688
```
But when I comment out it, there will be a difference
```
===============param
0.6796758770942688
0.6796759366989136
```
So I will make the tolerance a little higher than default one.

Co-authored-by: Jane Xu <janeyx@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123074
Approved by: https://github.com/jgong5, https://github.com/janeyx99
2024-04-19 11:14:04 +00:00
560efaa471 Part 1: UFMT partial files in torch/optim due to the pr-sanity-checks (#124053)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124053
Approved by: https://github.com/ezyang
ghstack dependencies: #124048
2024-04-16 03:17:18 +00:00
b5ba80828f [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks
3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly my preference after this PR has only been further cemented  that we should just do the single tensor and multi tensor capturable implementations together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.

The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 19:13:00 +00:00
2964170f3a Revert "[optim] Rectify capturable testing and fix bugs! (#118326)"
This reverts commit d947b9d50011ebd75db2e90d86644a19c4fe6234.

Reverted https://github.com/pytorch/pytorch/pull/118326 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it looks like there are some relevant failures in trunk d947b9d500, may be a land race ([comment](https://github.com/pytorch/pytorch/pull/118326#issuecomment-1923125676))
2024-02-02 07:08:14 +00:00
d947b9d500 [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks
3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added erroring and modified testing with many many many skips for that. Honestly my preference after this PR has only been further cemented  that we should just do the single tensor and multi tensor capturable implementations together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.

The following is not a bug, but is just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now mandatorily take in a `device`, which could be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs that get printed yields capturable getting included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 02:02:58 +00:00
17ecd1e9cd Migrate test_complex_optimizer to OptimizerInfo (#118160)
This PR does what it says and more.

1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about.
2. I rearranged the maximize conditional and the _view_as_real to preserve list-ness. This is needed for _view_as_real to function properly, I did add a comment in the Files Changed. This new order also just...makes more aesthetic sense.
3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
2024-01-24 21:22:47 +00:00
924f1b841a [optim] Allow torch.float64 scalars for forloop + foreach implementations (#115841)
Should allow for uses cases mentioned in #110940

This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.

Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).

This PR will conflict with my other PRs, I will figure out a landing order

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
2023-12-27 09:13:49 +00:00
5b6b680517 Revert "Adamw refactor (#115983)"
This reverts commit eafeba71c1ed35f8cf2d39016bf66c0b088e4a9f.

Reverted https://github.com/pytorch/pytorch/pull/115983 on behalf of https://github.com/jeanschmidt due to Breaking internal tests, @janeyx99 please help @tfsingh to have this PR landed ([comment](https://github.com/pytorch/pytorch/pull/115983#issuecomment-1862976954))
2023-12-19 15:26:44 +00:00
eafeba71c1 Adamw refactor (#115983)
Fixes #104899, refactors adamw by abstracting out common code in adam.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115983
Approved by: https://github.com/janeyx99
2023-12-17 06:58:39 +00:00
62de29d06f [optim] be explicit about CPU scalar tensor dtypes (#111008)
Fixes https://github.com/pytorch/pytorch/issues/110940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111008
Approved by: https://github.com/janeyx99
2023-11-21 22:44:50 +00:00
a2552d5521 Fixed docstring errors inside torch/cuda/ and torch/optim/ (Docathon H2) (#112964)
Fixes #112592
1) **File: torch/cuda/random.py**
```
Before:
/content/pytorch/torch/cuda/random.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
        D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
        D208: Docstring is over-indented
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
        D209: Multi-line docstring closing quotes should be on a separate line
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
        D414: Section has no content ('Args')
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
        D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
18
```

```
After:
/content/pytorch/torch/cuda/random.py:1 at module level:
        D100: Missing docstring in public module
1

```
2) **File: torch/cuda/amp/autocast_mode.py**
```
Before: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
        D102: Missing docstring in public method
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
        D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
        D401: First line should be in imperative mood; try rephrasing (found 'Helper')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
        D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
        D401: First line should be in imperative mood; try rephrasing (found 'Helper')
12
```
```
After:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
        D102: Missing docstring in public method
5
```

3)  **File: torch/cuda/amp/grad_scaler.py**
```
Before: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`:
        D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`:
        D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
        D400: First line should end with a period (not 'g')
/content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`:
        D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`:
        D401: First line should be in imperative mood (perhaps 'Update', not 'Updates')
/content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
        D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
        D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
        D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
        D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
        D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
        D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
        D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
        D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`:
        D401: First line should be in imperative mood (perhaps 'Load', not 'Loads')
/content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`:
        D105: Missing docstring in magic method
28
```
```
After:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`:
        D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`:
        D105: Missing docstring in magic method
5
```
4) **File: torch/optim/_functional.py**
```
Before:
/content/pytorch/torch/optim/_functional.py:1 at module level:
        D400: First line should end with a period (not 'e')
1
```
```
After:
0

```
5) **File: torch/optim/__init__.py**
```
Before:
/content/pytorch/torch/optim/__init__.py:1 at module level:
        D205: 1 blank line required between summary line and description (found 0)
1
```
```
After:
0

```
6) **File: torch/optim/lbfgs.py**
```
Before:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
        D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
        D400: First line should end with a period (not 'c')
/content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/lbfgs.py:285 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
5
```
```
After:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`:
        D107: Missing docstring in __init__
2
```
7)**File: torch/optim/sparse_adam.py**
```
Before: /content/pytorch/torch/optim/sparse_adam.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
4
```
```
After:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
        D107: Missing docstring in __init__
3
```
8) **File:torch/optim/adadelta.py**
```
Before:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adadelta.py:82 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`:
        D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
        D105: Missing docstring in magic method
4
```
9) **File: torch/optim/adagrad.py**
```
Before:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
        D102: Missing docstring in public method
/content/pytorch/torch/optim/adagrad.py:100 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`:
        D202: No blank lines allowed after function docstring (found 1)
7
```
```
After:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
        D102: Missing docstring in public method
5
```
10) **File: torch/optim/adam.py**
```
Before:
/content/pytorch/torch/optim/adam.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adam.py:135 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
        D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
        D205: 1 blank line required between summary line and description (found 0)
7
```
```
After:
/content/pytorch/torch/optim/adam.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
        D105: Missing docstring in magic method
4

```
11) **File: torch/optim/adamax.py**
```
Before:
/content/pytorch/torch/optim/adamax.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamax.py:91 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamax.py:203 in public function `adamax`:
        D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamax.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
        D105: Missing docstring in magic method
4
```
12) **File: torch/optim/adamw.py**
```
Before:
/content/pytorch/torch/optim/adamw.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamw.py:153 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamw.py:304 in public function `adamw`:
        D202: No blank lines allowed after function docstring (found 1)
6

```
```
After:
/content/pytorch/torch/optim/adamw.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
        D105: Missing docstring in magic method
4

```
13) **File: torch/optim/asgd.py**
```
Before:
/content/pytorch/torch/optim/asgd.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
        D105: Missing docstring in magic method
/content/pytorch/torch/optim/asgd.py:107 in public method `step`:
        D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/asgd.py:195 in public function `asgd`:
        D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/asgd.py:1 at module level:
        D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
        D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
        D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
        D105: Missing docstring in magic method
4
```
Resolved docstring errors as listed. I initially changed in the main branch of forked repo which caused changes to appear in my PR to other issue. I have fixed that and hope this PR won't have any conflicts.
Kindly review @svekars @jbschlosser.
In case of any other issues please let me know. Thanks!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964
Approved by: https://github.com/kit1980
2023-11-13 22:16:44 +00:00
f74d766632 feat(optim): use has_complex shortcut flag for all applicable optimizers, use _view_as_real auxiliary function (#110706)
Follow up to: https://github.com/pytorch/pytorch/pull/110607

CC: @lezcano @janeyx99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110706
Approved by: https://github.com/lezcano
2023-10-31 20:33:03 +00:00
93a9b1314b Make step() faster by passing in a tensor vs scalar 1 (#111084)
This is the culminated result of https://github.com/pytorch/pytorch/pull/110954#issuecomment-1758520411.

We are making the code slightly more complicated to gain some perf in minimizing calls to `.copy_()` and `.to()`.

### Code
```
import torch
with torch.cuda.device(0):
    steps = [torch.zeros((), device="cpu", dtype=torch.float32) for i in range(1000)]

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # New code:
        # step_device = steps[0].device
        # one = torch.tensor(1.0, device=step_device) if str(step_device) == "cpu" else 1
        # torch._foreach_add_(steps, one, 1.0)

        # Old code:
        torch._foreach_add_(steps, 1)

    print(p.key_averages().table(sort_by="cpu_time_total"))
```

### Profiles
**with old code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
      aten::_foreach_add_        35.31%      52.089ms        99.99%     147.495ms     147.495ms             1
               aten::add_        25.05%      36.949ms        64.68%      95.406ms      95.406us          1000
                 aten::to         3.97%       5.852ms        39.63%      58.457ms      58.457us          1000
           aten::_to_copy        10.11%      14.917ms        35.66%      52.605ms      52.605us          1000
              aten::copy_        21.65%      31.939ms        21.65%      31.939ms      31.939us          1000
      aten::empty_strided         3.90%       5.749ms         3.90%       5.749ms       5.749us          1000
    cudaDeviceSynchronize         0.01%      18.000us         0.01%      18.000us      18.000us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 147.513ms
```

**with new code**
```
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
      aten::_foreach_add_        55.06%      49.963ms        99.86%      90.625ms      90.625ms             1
               aten::add_        44.81%      40.662ms        44.81%      40.662ms      40.662us          1000
            aten::detach_         0.01%       8.000us         0.05%      45.000us      45.000us             1
                  detach_         0.04%      37.000us         0.04%      37.000us      37.000us             1
              aten::empty         0.03%      30.000us         0.03%      30.000us      30.000us             1
                 aten::to         0.03%      23.000us         0.03%      23.000us      23.000us             1
    cudaDeviceSynchronize         0.02%      22.000us         0.02%      22.000us      22.000us             1
         aten::lift_fresh         0.01%       6.000us         0.01%       6.000us       6.000us             1
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 90.751ms
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111084
Approved by: https://github.com/albanD
ghstack dependencies: #111079
2023-10-20 01:34:08 +00:00
b460c30893 [BE] Enable Ruff's Flake8 PYI042 (#111114)
Enable [snake-case-type-alias (PYI042)](https://docs.astral.sh/ruff/rules/snake-case-type-alias/)

Link: #110950
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111114
Approved by: https://github.com/albanD
2023-10-13 16:33:07 +00:00
d279979102 perf(inductor): improve Adam compile times by shortcutting for loops (via has_complex) (#110607)
Adam part of: https://github.com/pytorch/pytorch/issues/110506

TODO:
- If this approach is validated as a good one, it an also be applied to all other optimizers which convert `complex` via list comprehensions

### Results:
`NUM_PARAMS=200, foreach=True`
- main: dynamo: 43s, inductor: 31s, total: 74s
- this PR: dynamo: 3.5s, inductor: 30s, total: 34s (dynamo speedup: 12.3x, overall speedup: 34s, 2.1x)

`NUM_PARAMS=1000, foreach=True, has_complex shortcut`:

```
<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0329, 50.0806, 0.0041
OutputGraph.call_user_compiler        44.9924
```

`NUM_PARAMS=1000, foreach=True`:
```
<class 'torch.optim.adam.Adam'> {'lr': 0.01, 'foreach': True} torch.float32 TorchDynamo compilation metrics:
Function                              Runtimes (s)
------------------------------------  -------------------------------
_compile.<locals>.compile_inner       0.0389, 58.6069, 0.0043
OutputGraph.call_user_compiler        44.1425
```

### Discussion
- `has_complex` shortcut provides additional 2x dynamo speedup. It is not necessary to achieve a significant overall speedup.

CC: @janeyx99 @mlazos

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110607
Approved by: https://github.com/janeyx99, https://github.com/lezcano
2023-10-06 05:08:49 +00:00
df7d01aed5 perf(inductor): use for loop with shortcut in Optimizers to speedup against list comprehensions (e.g. complex conversion) (#110613)
Fully fixes: https://github.com/pytorch/pytorch/issues/110506

Depends: https://github.com/pytorch/pytorch/pull/110607
Potential merge conflicts:
- https://github.com/pytorch/pytorch/pull/110339
- https://github.com/pytorch/pytorch/pull/110345
- https://github.com/pytorch/pytorch/pull/110454

Related:
- https://github.com/pytorch/pytorch/issues/110606 (we can apply the improvements here orthogonally to the complex support)

### Results

Benchmark: 100 params.

Breakdowns (float32, dynamo):
```
Adagrad: this PR: 4.4s, main: 8.8s
Adam: this PR: 2.1s, main: 9.8s
AdamW: this PR: 2.5s, main: 8.2s
ASGD: this PR: 3.1s, main: 8.5s
RMSProp: this PR: 1.3s, main: 4.2s
RProp: this PR: 6.7s, main: 14.9s
```

Notes:
1. Adagrad is still slow due to `_get_value` list comprehension. Can be fixed in https://github.com/pytorch/pytorch/pull/110339/files by utilizing capturable path
2. Adamax is not actually compiled (it is currently disabled).
3. Inductor compile time is quite variable. We calculate dynamo by subtracting `call_user_compiler` from `compile_inner` timing.

<details>

This PR:
```
Adagrad (torch.float32): 28.47496461868286s
Adagrad (torch.complex64): 29.379547357559204s
Adam (torch.float32): 17.334211587905884s
Adam (torch.complex64): 29.637500524520874s
Adamax (torch.float32): 2.4749321937561035s
Adamax (torch.complex64): 3.1997995376586914s
AdamW (torch.float32): 18.06532859802246s
AdamW (torch.complex64): 28.25661015510559s
ASGD (torch.float32): 23.70255398750305s
ASGD (torch.complex64): 25.33756995201111s
RMSprop (torch.float32): 7.964028596878052s
RMSprop (torch.complex64): 12.909599781036377s
Rprop (torch.float32): 30.512362003326416s
Rprop (torch.complex64): 44.74405765533447s
```

Main
```
Adagrad (torch.float32): 26.919506072998047s
Adagrad (torch.complex64): 35.190622091293335s
Adam (torch.float32): 25.715000867843628s
Adam (torch.complex64): 24.17716670036316s
Adamax (torch.float32): 2.4404726028442383s
Adamax (torch.complex64): 3.3538928031921387s
AdamW (torch.float32): 25.2022807598114s
AdamW (torch.complex64): 28.915700912475586s
ASGD (torch.float32): 24.108731985092163s
ASGD (torch.complex64): 26.589075088500977s
RMSprop (torch.float32): 10.781344175338745s
RMSprop (torch.complex64): 15.136352777481079s
Rprop (torch.float32): 42.46482181549072s
Rprop (torch.complex64): 48.28277635574341s
```

Seems that it doesn't help the complex case by much (but that's not the majority case). torch.float32 is generally positive, when it does not show drastic improvement / regresses, it is due to inductor variance (by manually inspecting the logs).

</details>

### Benchmark Script
```python
import torch
import time
from torch.optim import Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop

OPTIMS = [Adagrad, Adam, Adamax, AdamW, ASGD, RMSprop, Rprop]
DTYPES = [torch.float, torch.cfloat]

NUM_PARAMS = 100
kwargs = { "lr": 0.01, "foreach": True }
summary = []

for optim_cls in OPTIMS:
    for dtype in DTYPES:
        torch._dynamo.reset()
        # torch._inductor.metrics.reset()
        input = torch.ones([10, 10], dtype=dtype, device="cuda:0")
        model = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, dtype=dtype, device="cuda:0") for _ in range(NUM_PARAMS)]
        )

        model(input).sum().abs().backward()
        opt_compiled = optim_cls(model.parameters(), **kwargs)
        compiled_step = torch.compile(opt_compiled.step)

        with torch.set_grad_enabled(False):
            start_time = time.time()
            compiled_step()
            summary.append(f"{optim_cls.__name__} ({dtype}): {time.time() - start_time}s")

        print(optim_cls, kwargs, dtype, torch._dynamo.utils.compile_times())

for s in summary:
    print(s)
```

CC: @janeyx99 @mlazos
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110613
Approved by: https://github.com/janeyx99
2023-10-05 23:10:52 +00:00
1641d671e5 [optim] FusedAdam/W accepts lr: Tensor without h2ds (#106916)
Starts addressing #106802

This PR also conveniently does some BE:
- Fixes a bug in adamw where we use amsgrad instead of per group amsgrad
- Brings the impls of adamw and adam closer to correctness and to each other

I couldn't fully remove the .pyi's because mypy was going to complain about the entire files which scared me and shouldn't go in this PR anyway.

Test plan:
- Add tests to ensure that lr could be passed as a Tensor
- Did some profiling of the below code (runs 1k iterations of step for Adam)

```
import torch
from torch.testing._internal.common_utils import TestCase

param = torch.rand(2, 3, dtype=torch.float, device='cuda:0', requires_grad=True)
param.grad = torch.rand_like(param)

lr = torch.tensor(.001, device='cuda:0')
opt = torch.optim.Adam([param], lr=lr, fused=True)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    for _ in range(1000):
        opt.step()

print(p.key_averages().table(sort_by="cpu_time_total"))

```

Before my change:
<img width="1381" alt="image" src="https://github.com/pytorch/pytorch/assets/31798555/cfc5175a-0f41-4829-941f-342554f3b152">

After my change (notice there are no d2h syncs and the CPU time is lower!):
![image](https://github.com/pytorch/pytorch/assets/31798555/726d7e66-dcff-4a4f-8a75-e84329961989)

Next steps long term:
- have all capturable foreach + forloop impls in Adam(W) handle tensor LR
- have all capturable impls handle tensor LR
- have all impls handle tensor LR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106916
Approved by: https://github.com/albanD
2023-08-21 23:00:44 +00:00
608afe8083 Added xla friendly codepath to single_tensor_adamw (#102858)
There are extra graph compilations on XLA when beta{1,2} ** step get too small. This PR addresses this issue by making the `capturable` interface enabled for XLA, as well as switching to `torch.float_power` which preserves the same behaviour as the non-capturable flow on XLA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102858
Approved by: https://github.com/janeyx99, https://github.com/albanD
2023-08-18 00:16:28 +00:00
21ede4547a remove duplicated code in optimizer (#106022)
Fixes #ISSUE_NUMBER
as the title, the check code  has duplicates
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106022
Approved by: https://github.com/janeyx99
2023-07-26 17:01:28 +00:00
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unusued loop values in python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
e1296a7f8d [Adam] Fix complex x amsgrad support (#104989)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104989
Approved by: https://github.com/albanD
2023-07-21 23:43:26 +00:00
25d80c69ce [foreach] super minor BE: remove unnecessary cast (#105601)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105601
Approved by: https://github.com/albanD
2023-07-20 17:06:52 +00:00
3721fa5612 [BE] Enable ruff's UP rules and autoformat optim/ (#105426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105426
Approved by: https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi, https://github.com/janeyx99
2023-07-18 21:07:43 +00:00
ef05c5f202 Use plain power operator in Adam/Adamw when capturing (#104254)
The goal is to fix the problem from https://github.com/pytorch/pytorch/pull/102858

The full error this used to raise was :
```
2023-06-27T15:12:15.0663239Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adamw.py", line 409, in _single_tensor_adamw
2023-06-27T15:12:15.0663699Z     bias_correction1 = 1 - beta1 ** step
2023-06-27T15:12:15.0664200Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 40, in wrapped
2023-06-27T15:12:15.0664547Z     return f(*args, **kwargs)
2023-06-27T15:12:15.0665031Z   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/_tensor.py", line 882, in __rpow__
2023-06-27T15:12:15.0665483Z     return torch.tensor(other, dtype=dtype, device=self.device) ** self
2023-06-27T15:12:15.0665899Z RuntimeError: CUDA error: operation not permitted when stream is capturing
2023-06-27T15:12:15.0666401Z CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
```

This pow issue was fixed in https://github.com/pytorch/pytorch/pull/104264 and so this problem should be solvable now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104254
Approved by: https://github.com/janeyx99, https://github.com/aws-murandoo
2023-07-13 19:24:25 +00:00
231364fd06 [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Internal CI (and the newly enabled compiled optim tests) will pass after https://github.com/pytorch/pytorch/pull/104866 is landed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-11 14:32:59 +00:00
35f0e35529 [foreach][Adam] Minimize use of intermediates to decrease peak memory (#104780)
Starts addressing https://github.com/pytorch/pytorch/issues/97712 by
- Minimizing intermediates usage for foreach Adam
- Document the extra memory usage
- Add comments within the code for clarity now that we reuse intermediates
- Add tests
- Did some refactoring

Next steps involve doing this for all other foreach implementations. Note that even after this change, foreach mem usage will be higher than forloop due to the fact that we have a minimum budget of 1 intermediate (to not muddle the input values) and the intermediate will be larger. For capturable, the memory usage is higher due to moving more tensors to CUDA.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104780
Approved by: https://github.com/albanD
2023-07-10 17:38:46 +00:00
e7fe2a797c Revert "[optim] use lerp whenever possible (#104796)"
This reverts commit fbe2a7e50a940ba7a12b003241a2699f7a731afb.

Reverted https://github.com/pytorch/pytorch/pull/104796 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/104796#issuecomment-1628591105))
2023-07-10 09:36:41 +00:00
fbe2a7e50a [optim] use lerp whenever possible (#104796)
This is a better copy (with fixes) of #104781.

Test plan:
CI will pass once https://github.com/pytorch/pytorch/pull/104784 is landed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104796
Approved by: https://github.com/albanD
2023-07-08 07:13:38 +00:00
a290cbf32b Enable fused foreach Adam compilation (#104121)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
2023-07-05 23:40:03 +00:00
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors,  see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
fa893f3f58 Fix optim state_dict casting to allow step to cast to CPU (#102619)
I'm guessing this should fix https://github.com/pytorch/pytorch/pull/88015#issuecomment-1569523106 but am waiting on @ychfan to supply more details so I could write a good test case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102619
Approved by: https://github.com/albanD
2023-06-13 00:46:40 +00:00
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e698378882d30958908073407f97fb3.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
e4a42bcf56 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-07 13:59:20 +00:00
4da88447ea Disable grouping by dtype and device if compiling (#102771)
Disable grouping if we are compiling, this happens during lowering
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102771
Approved by: https://github.com/janeyx99
2023-06-02 21:04:49 +00:00
9d77949b9e Revert "add foreach support for custom device (#102047)"
This reverts commit b088ff467794bc1125133fb0428749d5bcd6ae3a.

Reverted https://github.com/pytorch/pytorch/pull/102047 on behalf of https://github.com/malfet due to Broke inductor, see b088ff4677 ([comment](https://github.com/pytorch/pytorch/pull/102047#issuecomment-1572368942))
2023-06-01 16:33:03 +00:00
b088ff4677 add foreach support for custom device (#102047)
Fixes #ISSUE_NUMBER
for custom device, we want to support foreach, so I add a func that we could set other device type, and the default value is cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102047
Approved by: https://github.com/janeyx99
2023-06-01 06:22:44 +00:00
f558af2a55 [adam] Use the right params in weight_decay, rename for clarity, fixes #100707 (#100973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100973
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-05-09 17:00:27 +00:00
22ea21da3d Change 1D Tensor of 1 element to 0D Tensor (#96994)
add 0d tensor to graph adam/adamw test

Affected:
- `torch.cuda.amp.GradScaler`'s `found_inf`, `_scale`, and `_growth_tracker`
- `step` of Adam & AdamW of `capturable`

Fixes #96776 🤞

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96994
Approved by: https://github.com/janeyx99
2023-03-21 18:24:19 +00:00
7d765cdc66 Fix wrong handling of grad_scale & found_inf in fused optimizers (#95847)
Fixes #95781.
The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`. Therefore parameters can get mistakenly updated by gradients whose some elements are invalid, i.e. nan or inf.

Related #94060

I forgot about this wrong handling after #94344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95847
Approved by: https://github.com/janeyx99
2023-03-04 01:21:21 +00:00
75cb99e549 [optim] Widen the cases for defaulting to foreach (#95820)
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.

The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
2023-03-02 04:15:33 +00:00
097679478e [optim] Set defaults to foreach, NOT fused (#95241)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.

Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
2023-02-22 04:47:32 +00:00
3e9df622fb [mta] implement _foreach_pow (#92303)
Mainly for foreach path of `Adam` and `AdamW`

rel: https://github.com/pytorch/pytorch/issues/58833
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92303
Approved by: https://github.com/albanD
2023-02-16 02:28:26 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
6ba041fcae Look up group["capturable"], not defaults["capturable"] in Adam(W) (#94149)
We could set different values in each `param_group` when calling dunder init of `torch.optim` optimizers as in e.g.  https://github.com/pytorch/pytorch/issues/89987.

So check whether or not `capturable` is `True` among all the `param_group`s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94149
Approved by: https://github.com/albanD
2023-02-07 00:24:35 +00:00
a23ed38f9a [mta][foreach] Implement fused adamw (#88015)
related: https://github.com/pytorch/pytorch/issues/68041, https://github.com/pytorch/pytorch/issues/71274, https://github.com/pytorch/pytorch/issues/80167
possibly related to https://github.com/pytorch/pytorch/issues/80595#issuecomment-1178519436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88015
Approved by: https://github.com/albanD, https://github.com/ngimel
2023-02-01 19:32:29 +00:00
d7a3f2128f pass None instead of False inside Adam.__setstate__ (#93289)
with a061f139dc, `fused`'s type hint is `Optional[bool]` and its default value is `None`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93289
Approved by: https://github.com/janeyx99, https://github.com/Skylion007
2023-01-31 09:41:35 +00:00
4fc19e1a71 [optim][adam] use fastest impl whenever possible, add util (#93184)
This allows it so that ONLY when the users don't set anything for foreach or fused do we switch the default and cascades adam so that we default to fused, then foreach, then single-tensor.

To clarify:
* if the user puts True in foreach _only_, it will run the foreach implementation.
* if the user puts True in fused _only_, it will run the fused implementation.
* if the user puts True in foreach AND for fused, it will run the fused implementation.

And:
* if the user puts False in foreach _only_, it will run the single tensor implementation.
* if the user puts False in fused _only_, it will still run the single tensor implementation.
* if the user puts False in foreach AND for fused, it will run the single tensor implementation.

I also didn't trust myself that much with the helper function, so I ran some local asserts on _default_to_fused_or_foreach. The only point left to really test is the type(p) -- torch.Tensor but I think the distributed tests will catch that in CI.
```
cuda_only_fp_list = [
    torch.rand((1, 2), device="cuda", dtype=torch.float32),
    torch.rand((1, 2), device="cuda", dtype=torch.float64),
    torch.rand((1, 2), device="cuda", dtype=torch.float16),
    torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]

cuda_only_int_list = [
    torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]

cpu_list = [
    torch.rand((1, 2), device="cpu", dtype=torch.float32),
    torch.rand((1, 2), device="cpu", dtype=torch.float64),
    torch.rand((1, 2), device="cpu", dtype=torch.float16),
]

none_list = [None]

# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)

# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)

# has fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)

# ints always goes to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)

# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD
2023-01-30 19:58:55 +00:00
de0375e79d [optim][foreach] Do NOT inplace modify gradients (#92706)
SGD and ASGD already had out-of-place grads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92706
Approved by: https://github.com/ngimel, https://github.com/albanD
2023-01-21 00:12:28 +00:00
e4d83d54a6 Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions

```
[------------------------ (tensors, scalar) -------------------------]
                                   |  without foreach  |  with foreach |    apex
1 threads: ----------------------------------------------------------------------
      10 tensors of size 4         |         120.5     |       61.1    |     50.3
      100 tensors of size 4        |         946.2     |      239.5    |    136.3
      1000 tensors of size 4       |        9808.5     |     2151.1    |   1006.9
      10000 tensors of size 4      |       96871.2     |    22637.4    |  10119.1
      10 tensors of size 16        |         121.0     |       64.1    |     52.5
      100 tensors of size 16       |         993.4     |      252.6    |    136.7
      1000 tensors of size 16      |        9427.7     |     2151.2    |   1049.5
      10000 tensors of size 16     |       97437.1     |    22203.1    |  10340.0
      10 tensors of size 256       |         118.9     |       62.3    |     51.5
      100 tensors of size 256      |         955.2     |      243.1    |    134.2
      1000 tensors of size 256     |        9374.9     |     2140.7    |   1009.6
      10000 tensors of size 256    |       95302.5     |    21849.4    |  10215.5
      10 tensors of size 65536     |         118.5     |       62.4    |     51.1
      100 tensors of size 65536    |        1740.7     |      243.3    |    225.3
      1000 tensors of size 65536   |       17364.1     |     2228.7    |   2004.5
      10000 tensors of size 65536  |      177510.1     |    25410.4    |  20678.2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91846
Approved by: https://github.com/janeyx99
2023-01-20 21:43:29 +00:00