A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**
(also mentioned in here: [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552)
## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their id.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict =
{
'state': {
0: {'momentum_buffer': tensor(...), ...},
1: {'momentum_buffer': tensor(...), ...},
},
'param_groups': [
{
'lr': 0.01,
'weight_decay': 0,
...
'params': [0,1]
'param_names' ['layer.weight', 'layer.bias'] (optional)
}
]
}
```
Loading `state_dict` is not changed (backward-compatible) and the `param_names` key will be ignored.
## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.
## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state to the optimizer.
#### Customization with Hooks:
For more control, the loaded state_dict can be modified using a custom `register_load_state_dict_pre_hook`, providing flexibility for different design needs.
## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.
## Solution Example:
A suggested solution to the problem mentioned in #1489, for the same parameters but in a different order.
The following `register_load_state_dict_pre_hook` should be added to the optimizer before loading to enable loading the state dict :
```python
def adapt_state_dict_ids(optimizer, state_dict):
# assuming a single param group.
current_state_group = optimizer.state_dict()['param_groups'][0]
loaded_state_group = state_dict['param_groups'][0]
# same number of params, same names, only different ordering
current_state_name_to_id_mapping = {} # mapping -- param_name: id
for i, name in enumerate(current_state_group['param_names']):
current_state_name_to_id_mapping[name] = current_state_group['params'][i]
# changing the ids of the loaded state dict to match the order of the given state dict.
for i, name in enumerate(current_state_group['param_names']):
loaded_state_group['params'][i] = current_state_name_to_id_mapping[name]
return state_dict
```
In this code, the loaded `state_dict` ids are adapted to match the order of the current optimizer `state_dict`.
Both the previous and the current optimizers are required to be initiated with `named_parameters()` to have the 'param_names' key in the dict.
### Note
This is my first contribution to PyTorch, and I wish to receive feedback or suggestions for improvement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Add the missing documentation for `initial_accumulator_value` parameter in Adagrad, and update the algorithm description in the documentation (adjusted to reflect the implementation).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125886
Approved by: https://github.com/janeyx99
This is the last of the old TestOptim! With this change, everything will be migrated to use OptimizerInfo. Our sparse support is...well, sparse, and the tests try to best encapsulate which configs actually work. Note that support_sparse is actually just supports sparse grads...we don't test sparse params.
1. This PR fixes a bug in Adagrad multi_tensor with maximize by passing the correct value of maximize (vs False everytime) when sparse values are present.
2. This PR does improve coverage. There used to only be 2 configs each, and now we have the following configs for:
Adagrad:
```
python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Adagrad
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'lr': 0.1} <--- this and above are CPU
.{'foreach': False, 'lr': 0.1}
{'foreach': True, 'lr': 0.1}
{'maximize': True, 'foreach': False, 'lr': 0.1}
{'maximize': True, 'foreach': True, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': False, 'lr': 0.1}
{'initial_accumulator_value': 0.1, 'foreach': True, 'lr': 0.1}
.
----------------------------------------------------------------------
Ran 2 tests in 227.744s
OK
```
SGD
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_SGD
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'dampening': 0.5, 'lr': 0.0048}
.{'foreach': False, 'lr': 0.0048}
{'foreach': True, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': False, 'lr': 0.0048}
{'dampening': 0.5, 'foreach': True, 'lr': 0.0048}
.
----------------------------------------------------------------------
Ran 2 tests in 112.801s
OK
```
SparseAdam
```
(pytorch-3.10) [janeyx@devgpu023.odn1 /data/users/janeyx/pytorch (bff23193)]$ python test/test_optim.py -k test_rosenbrock_sparse_with_lrsched_False_Sparse
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
{'maximize': True, 'lr': 0.04}
.{'maximize': True, 'lr': 0.04}
.
----------------------------------------------------------------------
Ran 2 tests in 35.113s
OK
```
Fixes#103322. A side quest in this migration was to re-enable and track dynamo issues as they trigger on the optim tests, which will be complete from this PR. New tests may add more things to track in dynamo, but there is now an established system for doing so, and dynamo is either enabled or a bug is tracked for every migrated test in TestOptimRenewed.
Next steps:
Remove the hyperparameter constraints in common_optimizer.py defined by metadata_for_sparse (other than LR, which seems handpicked for the tests to actually pass). Doing this requires adding more sparse functionality.
Add more tests!
Maybe add more optimizers!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123146
Approved by: https://github.com/albanD
ghstack dependencies: #123134, #123139
This PR does what it says and more.
1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about.
2. I rearranged the maximize conditional and the _view_as_real to preserve list-ness. This is needed for _view_as_real to function properly, I did add a comment in the Files Changed. This new order also just...makes more aesthetic sense.
3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
Should allow for uses cases mentioned in #110940
This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.
Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).
This PR will conflict with my other PRs, I will figure out a landing order
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
Fixes#112592
1) **File: torch/cuda/random.py**
```
Before:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D208: Docstring is over-indented
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D209: Multi-line docstring closing quotes should be on a separate line
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D414: Section has no content ('Args')
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
18
```
```
After:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
1
```
2) **File: torch/cuda/amp/autocast_mode.py**
```
Before: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
12
```
```
After:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
5
```
3) **File: torch/cuda/amp/grad_scaler.py**
```
Before: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D400: First line should end with a period (not 'g')
/content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`:
D401: First line should be in imperative mood (perhaps 'Update', not 'Updates')
/content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`:
D401: First line should be in imperative mood (perhaps 'Load', not 'Loads')
/content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`:
D105: Missing docstring in magic method
28
```
```
After:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`:
D105: Missing docstring in magic method
5
```
4) **File: torch/optim/_functional.py**
```
Before:
/content/pytorch/torch/optim/_functional.py:1 at module level:
D400: First line should end with a period (not 'e')
1
```
```
After:
0
```
5) **File: torch/optim/__init__.py**
```
Before:
/content/pytorch/torch/optim/__init__.py:1 at module level:
D205: 1 blank line required between summary line and description (found 0)
1
```
```
After:
0
```
6) **File: torch/optim/lbfgs.py**
```
Before:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D400: First line should end with a period (not 'c')
/content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/lbfgs.py:285 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
5
```
```
After:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`:
D107: Missing docstring in __init__
2
```
7)**File: torch/optim/sparse_adam.py**
```
Before: /content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
4
```
```
After:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
3
```
8) **File:torch/optim/adadelta.py**
```
Before:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adadelta.py:82 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
9) **File: torch/optim/adagrad.py**
```
Before:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
/content/pytorch/torch/optim/adagrad.py:100 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`:
D202: No blank lines allowed after function docstring (found 1)
7
```
```
After:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
5
```
10) **File: torch/optim/adam.py**
```
Before:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adam.py:135 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D205: 1 blank line required between summary line and description (found 0)
7
```
```
After:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
11) **File: torch/optim/adamax.py**
```
Before:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamax.py:91 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamax.py:203 in public function `adamax`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
12) **File: torch/optim/adamw.py**
```
Before:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamw.py:153 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamw.py:304 in public function `adamw`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
13) **File: torch/optim/asgd.py**
```
Before:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/asgd.py:107 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/asgd.py:195 in public function `asgd`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
Resolved docstring errors as listed. I initially changed in the main branch of forked repo which caused changes to appear in my PR to other issue. I have fixed that and hope this PR won't have any conflicts.
Kindly review @svekars @jbschlosser.
In case of any other issues please let me know. Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964
Approved by: https://github.com/kit1980
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.
The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused.
Since our fused implementations are relatively newer, let's give them a longer bake-in time before flipping the switch for every user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95241
Approved by: https://github.com/ngimel
This allows it so that ONLY when the users don't set anything for foreach or fused do we switch the default and cascades adam so that we default to fused, then foreach, then single-tensor.
To clarify:
* if the user puts True in foreach _only_, it will run the foreach implementation.
* if the user puts True in fused _only_, it will run the fused implementation.
* if the user puts True in foreach AND for fused, it will run the fused implementation.
And:
* if the user puts False in foreach _only_, it will run the single tensor implementation.
* if the user puts False in fused _only_, it will still run the single tensor implementation.
* if the user puts False in foreach AND for fused, it will run the single tensor implementation.
I also didn't trust myself that much with the helper function, so I ran some local asserts on _default_to_fused_or_foreach. The only point left to really test is the type(p) -- torch.Tensor but I think the distributed tests will catch that in CI.
```
cuda_only_fp_list = [
torch.rand((1, 2), device="cuda", dtype=torch.float32),
torch.rand((1, 2), device="cuda", dtype=torch.float64),
torch.rand((1, 2), device="cuda", dtype=torch.float16),
torch.rand((1, 2), device="cuda", dtype=torch.bfloat16),
]
cuda_only_int_list = [
torch.randint(1024, (1, 2), device="cuda", dtype=torch.int64),
]
cpu_list = [
torch.rand((1, 2), device="cpu", dtype=torch.float32),
torch.rand((1, 2), device="cpu", dtype=torch.float64),
torch.rand((1, 2), device="cpu", dtype=torch.float16),
]
none_list = [None]
# differentiable should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], True, False) == (False, False)
# cpu lists should always make it return false for both
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, True) == (False, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cpu_list], False, False) == (False, False)
assert _default_to_fused_or_foreach([cpu_list], False, False) == (False, False)
# has fused triggers correctly
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list], False, False) == (False, True)
# ints always goes to foreach
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list], False, False) == (False, True)
# Nones don't error
assert _default_to_fused_or_foreach([cuda_only_fp_list, none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([cuda_only_fp_list, cuda_only_int_list, none_list], False, True) == (False, True)
assert _default_to_fused_or_foreach([none_list], False, True) == (True, False)
assert _default_to_fused_or_foreach([none_list], False, False) == (False, True)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93184
Approved by: https://github.com/albanD