A proposal addressing Issue #1489: **Optimizer should track parameter names and not id.**
(also discussed in [[RFC] Introducing FQNs/clarity eyeglasses to optim state_dict](https://dev-discuss.pytorch.org/t/rfc-introducing-fqns-clarity-to-optim-state-dict/1552))
## Summary
This PR introduces a backward-compatible enhancement where optimizers track parameter names instead of just their ids.
Optimizers can be initialized with `named_parameters()` as:
```python
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
```
This allows for greater clarity and ease when handling optimizers, as the parameters' names are preserved within the optimizer’s `state_dict` as:
```
state_dict = {
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0, 1],
            'param_names': ['layer.weight', 'layer.bias']  (optional)
        }
    ]
}
```
Loading a `state_dict` is unchanged (backward compatible); the `param_names` key, if present, is simply ignored.
## Key Features
#### Named Parameters in Optimizer Initialization:
Optimizers can accept the output of `model.named_parameters()` during initialization, allowing them to store parameter names directly.
#### Parameter Names in `state_dict`:
The parameter names are saved as a list in the optimizer’s `state_dict` with key `param_names`, alongside the `params` indices, ensuring seamless tracking of both names and parameters.
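For illustration, the names can be read back from the `state_dict` alongside the indices (a minimal sketch using a small module):
```python
from torch import nn, optim

model = nn.Linear(2, 2)
optimizer = optim.SGD(model.named_parameters(), lr=0.01)

group = optimizer.state_dict()['param_groups'][0]
print(group['params'])       # [0, 1]
print(group['param_names'])  # ['weight', 'bias']
```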
## Backward Compatibility
#### No Breaking Changes:
This change is fully backward-compatible. The added `param_names` key in the optimizer's `state_dict` is ignored when loading a state into the optimizer.
#### Customization with Hooks:
For more control, the loaded `state_dict` can be modified with a custom hook registered via `register_load_state_dict_pre_hook`, providing flexibility for different design needs.
## Documentation Updates
Please refer to the documentation changes for more details on how this feature is implemented and how it can be used effectively.
## Solution Example:
A suggested solution to the problem mentioned in #1489, for a checkpoint containing the same parameters but in a different order.
The following pre-hook should be registered on the optimizer (via `register_load_state_dict_pre_hook`) before loading, to enable loading the state dict:
```python
def adapt_state_dict_ids(optimizer, state_dict):
    # assuming a single param group.
    current_state_group = optimizer.state_dict()['param_groups'][0]
    loaded_state_group = state_dict['param_groups'][0]

    # same number of params, same names, only different ordering
    loaded_name_to_id_mapping = {}  # mapping -- param_name: id in the loaded state dict
    for i, name in enumerate(loaded_state_group['param_names']):
        loaded_name_to_id_mapping[name] = loaded_state_group['params'][i]

    # reorder the loaded ids so that position i holds the saved id of the
    # parameter that sits at position i in the current optimizer.
    for i, name in enumerate(current_state_group['param_names']):
        loaded_state_group['params'][i] = loaded_name_to_id_mapping[name]
    return state_dict
```
In this code, the ids in the loaded `state_dict` are remapped so that they line up with the parameter order of the current optimizer's `state_dict`.
Both the previous and the current optimizer must be initialized with `named_parameters()` so that the `param_names` key is present in both state dicts.
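For illustration, a minimal usage sketch; here `old_state_dict` stands for a previously saved state dict whose optimizer was also constructed with `named_parameters()`, possibly with a different parameter order:
```python
from torch import nn, optim

model = nn.Linear(2, 2)
# the new optimizer must also be built from named_parameters()
# so that 'param_names' is present in its state_dict
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)

# remap the ids of the checkpointed state dict before it is loaded
optimizer.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer.load_state_dict(old_state_dict)
```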
### Note
This is my first contribution to PyTorch, and I welcome feedback and suggestions for improvement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134107
Approved by: https://github.com/janeyx99
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
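For illustration, a minimal sketch of the two patterns described above (function and message names are hypothetical):
```python
import warnings

from typing_extensions import deprecated


@deprecated("`old_api` is deprecated, use `new_api` instead.", category=FutureWarning)
def old_api() -> None:
    ...


def legacy_path() -> None:
    # fallback when the decorator cannot be applied: pass the category explicitly
    warnings.warn(
        "`legacy_path` is deprecated, use `new_path` instead.",
        category=FutureWarning,
        stacklevel=2,
    )
```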
Resolves #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
This resolves a bug in eager where, if an old state dict (saved without the capturable flag) is loaded but the original dict had the capturable flag, then state_steps would be on CUDA while we would take the non-capturable path. We now fall back to eager if capturable=False.
Current design doc and discussion: https://docs.google.com/document/d/1DmmbiaSp16CDZtGw1qzXKHFTY_0gqc0xpnBdviXq0vk/edit#heading=h.871u7bvwz7ze
Note on the actual fallback logic: there was an issue with TorchScript originally not handling `*args, **kwargs` properly. After rectifying that by using `functools.wraps`, there was an additional scoping bug that required the single-tensor implementation to be in the global scope at the time the fallback closure is created. I pass the single-tensor function into the `_disable_dynamo_if_unsupported` decorator to work around this bug.
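For intuition, a schematic sketch of the decorator pattern described above (not the actual PyTorch implementation; names and the fallback condition are placeholders):
```python
import functools


def disable_dynamo_if_unsupported_sketch(single_tensor_fn=None):
    # The single-tensor fallback is passed in explicitly so it is already bound
    # in this scope when the wrapper closure is created, sidestepping the
    # global-scope requirement mentioned above.
    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped function's signature/metadata
        def wrapper(*args, capturable=False, **kwargs):
            if not capturable and single_tensor_fn is not None:
                # placeholder condition: fall back to the single-tensor path
                return single_tensor_fn(*args, capturable=capturable, **kwargs)
            return func(*args, capturable=capturable, **kwargs)

        return wrapper

    return decorator
```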
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123619
Approved by: https://github.com/janeyx99
Finishes the work started in https://github.com/pytorch/pytorch/pull/118697. Thanks @MarouaneMaatouk for the attempt, but due to inactivity I have opened this PR for Adamax. Note that the new capturable implementation is much simpler and I've modified the foreach capturable impl--it now calls fewer kernels and is more easily comparable to forloop.
Next steps:
* This PR discovered two bugs: #121178 and #121238.
* Move the now hefty graph optim tests in test_cuda to use OptimInfo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121183
Approved by: https://github.com/albanD
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, since we didn't create the tensors on the right device in `__setstate__`. This has been fixed (see the sketch below).
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but is silently incorrect for forloop/single-tensor. I've added erroring and modified testing with many, many skips for that. Honestly, this PR has only further cemented my preference that we should just do the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.
The following is not a bug, but just something to make life simpler by not needing to handle Nones: `optim_input_funcs` must now take a `device` argument, which can be a string or a `torch.device`.
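As a rough illustration of fix (1), the kind of coercion described, converting a plain-number `step` to a tensor on the parameter's device when the capturable/fused path is used, might look like this (a sketch with hypothetical names, not the actual PyTorch code):
```python
import torch


def coerce_step(step, param, capturable_or_fused):
    # an old checkpoint may store `step` as a plain Python number, while the
    # capturable/fused paths expect a tensor living on the parameter's device
    if capturable_or_fused and not torch.is_tensor(step):
        return torch.tensor(float(step), dtype=torch.float32, device=param.device)
    return step
```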
Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the generated configs shows that capturable is now included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
_torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s
OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run shows that it is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
This PR does what it says and more.
1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about.
2. I rearranged the maximize conditional and the _view_as_real handling to preserve list-ness (see the toy sketch after this list). This is needed for _view_as_real to function properly; I added a comment in the Files Changed. This new order also just...makes more aesthetic sense.
3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know.
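For intuition, a toy sketch of why the ordering matters (illustrative only; the actual `_view_as_real` helper differs):
```python
import torch

# two toy complex parameters and their gradients
params = [torch.randn(3, dtype=torch.complex64) for _ in range(2)]
grads = [torch.randn(3, dtype=torch.complex64) for _ in range(2)]

# handle maximize first: this rebinds `grads` to a brand-new list
maximize = True
if maximize:
    grads = [-g for g in grads]  # the real foreach path negates with a foreach op

# only afterwards view complex tensors as real, writing back into the *same*
# lists, so every later (foreach) op sees the real-view tensors
for i in range(len(params)):
    if torch.is_complex(params[i]):
        params[i] = torch.view_as_real(params[i])
        grads[i] = torch.view_as_real(grads[i])
```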
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
Should allow for use cases mentioned in #110940.
This would allow scalars to also be float64s in the foreach implementation. The fused implementation would still create a float32 step on Adam and AdamW. This PR also does NOT worry about performance and is mainly for enablement.
Next steps:
- Relax the constraint on fused adam(w) and allow torch.float64 scalars there
- Allow _performant_ mixed dtypes in foreach (a bigger project in itself).
This PR will conflict with my other PRs; I will figure out a landing order.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115841
Approved by: https://github.com/albanD
Fixes #112592
1) **File: torch/cuda/random.py**
```
Before:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/random.py:21 in public function `get_rng_state`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/cuda/random.py:43 in public function `get_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/random.py:54 in public function `set_rng_state`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D208: Docstring is over-indented
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D209: Multi-line docstring closing quotes should be on a separate line
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:79 in public function `set_rng_state_all`:
D414: Section has no content ('Args')
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:88 in public function `manual_seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:110 in public function `manual_seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:128 in public function `seed`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/random.py:146 in public function `seed_all`:
D401: First line should be in imperative mood (perhaps 'Set', not 'Sets')
/content/pytorch/torch/cuda/random.py:167 in public function `initial_seed`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
18
```
```
After:
/content/pytorch/torch/cuda/random.py:1 at module level:
D100: Missing docstring in public module
1
```
2) **File: torch/cuda/amp/autocast_mode.py**
```
Before: /content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:18 in public class `autocast`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:90 in public function `custom_fwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D400: First line should end with a period (not 'f')
/content/pytorch/torch/cuda/amp/autocast_mode.py:130 in public function `custom_bwd`:
D401: First line should be in imperative mood; try rephrasing (found 'Helper')
12
```
```
After:
/content/pytorch/torch/cuda/amp/autocast_mode.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/autocast_mode.py:23 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/autocast_mode.py:38 in public method `__enter__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:44 in public method `__exit__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/autocast_mode.py:49 in public method `__call__`:
D102: Missing docstring in public method
5
```
3) **File: torch/cuda/amp/grad_scaler.py**
```
Before: /content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:17 in private class `_MultiDeviceReplicator`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:39 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:50 in public class `GradScaler`:
D400: First line should end with a period (not 'g')
/content/pytorch/torch/cuda/amp/grad_scaler.py:115 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:354 in public method `step`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:456 in public method `update`:
D401: First line should be in imperative mood (perhaps 'Update', not 'Updates')
/content/pytorch/torch/cuda/amp/grad_scaler.py:529 in public method `get_scale`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:544 in public method `get_growth_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:550 in public method `set_growth_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:557 in public method `get_backoff_factor`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:563 in public method `set_backoff_factor`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:570 in public method `get_growth_interval`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/cuda/amp/grad_scaler.py:576 in public method `set_growth_interval`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D200: One-line docstring should fit on one line with quotes (found 3)
/content/pytorch/torch/cuda/amp/grad_scaler.py:592 in public method `is_enabled`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D400: First line should end with a period (not ':')
/content/pytorch/torch/cuda/amp/grad_scaler.py:598 in public method `state_dict`:
D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/content/pytorch/torch/cuda/amp/grad_scaler.py:624 in public method `load_state_dict`:
D401: First line should be in imperative mood (perhaps 'Load', not 'Loads')
/content/pytorch/torch/cuda/amp/grad_scaler.py:649 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:665 in public method `__setstate__`:
D105: Missing docstring in magic method
28
```
```
After:
/content/pytorch/torch/cuda/amp/grad_scaler.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/cuda/amp/grad_scaler.py:40 in public class `OptState`:
D101: Missing docstring in public class
/content/pytorch/torch/cuda/amp/grad_scaler.py:117 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/cuda/amp/grad_scaler.py:647 in public method `__getstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/cuda/amp/grad_scaler.py:663 in public method `__setstate__`:
D105: Missing docstring in magic method
5
```
4) **File: torch/optim/_functional.py**
```
Before:
/content/pytorch/torch/optim/_functional.py:1 at module level:
D400: First line should end with a period (not 'e')
1
```
```
After:
0
```
5) **File: torch/optim/__init__.py**
```
Before:
/content/pytorch/torch/optim/__init__.py:1 at module level:
D205: 1 blank line required between summary line and description (found 0)
1
```
```
After:
0
```
6) **File: torch/optim/lbfgs.py**
```
Before:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D205: 1 blank line required between summary line and description (found 0)
/content/pytorch/torch/optim/lbfgs.py:185 in public class `LBFGS`:
D400: First line should end with a period (not 'c')
/content/pytorch/torch/optim/lbfgs.py:215 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/lbfgs.py:285 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
5
```
```
After:
/content/pytorch/torch/optim/lbfgs.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/lbfgs.py:217 in public method `__init__`:
D107: Missing docstring in __init__
2
```
7) **File: torch/optim/sparse_adam.py**
```
Before: /content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/sparse_adam.py:40 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
4
```
```
After:
/content/pytorch/torch/optim/sparse_adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/sparse_adam.py:7 in public class `SparseAdam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/sparse_adam.py:8 in public method `__init__`:
D107: Missing docstring in __init__
3
```
8) **File: torch/optim/adadelta.py**
```
Before:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adadelta.py:82 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adadelta.py:193 in public function `adadelta`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adadelta.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adadelta.py:11 in public class `Adadelta`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adadelta.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adadelta.py:44 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
9) **File: torch/optim/adagrad.py**
```
Before:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
/content/pytorch/torch/optim/adagrad.py:100 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adagrad.py:201 in public function `adagrad`:
D202: No blank lines allowed after function docstring (found 1)
7
```
```
After:
/content/pytorch/torch/optim/adagrad.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adagrad.py:11 in public class `Adagrad`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adagrad.py:12 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adagrad.py:63 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adagrad.py:78 in public method `share_memory`:
D102: Missing docstring in public method
5
```
10) **File: torch/optim/adam.py**
```
Before:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adam.py:135 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D202: No blank lines allowed after function docstring (found 1)
/content/pytorch/torch/optim/adam.py:281 in public function `adam`:
D205: 1 blank line required between summary line and description (found 0)
7
```
```
After:
/content/pytorch/torch/optim/adam.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adam.py:14 in public class `Adam`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adam.py:15 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adam.py:65 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
11) **File: torch/optim/adamax.py**
```
Before:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamax.py:91 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamax.py:203 in public function `adamax`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamax.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamax.py:12 in public class `Adamax`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamax.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamax.py:47 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
12) **File: torch/optim/adamw.py**
```
Before:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/adamw.py:153 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/adamw.py:304 in public function `adamw`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/adamw.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/adamw.py:12 in public class `AdamW`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/adamw.py:13 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/adamw.py:73 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
13) **File: torch/optim/asgd.py**
```
Before:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
/content/pytorch/torch/optim/asgd.py:107 in public method `step`:
D401: First line should be in imperative mood (perhaps 'Perform', not 'Performs')
/content/pytorch/torch/optim/asgd.py:195 in public function `asgd`:
D202: No blank lines allowed after function docstring (found 1)
6
```
```
After:
/content/pytorch/torch/optim/asgd.py:1 at module level:
D100: Missing docstring in public module
/content/pytorch/torch/optim/asgd.py:17 in public class `ASGD`:
D101: Missing docstring in public class
/content/pytorch/torch/optim/asgd.py:18 in public method `__init__`:
D107: Missing docstring in __init__
/content/pytorch/torch/optim/asgd.py:52 in public method `__setstate__`:
D105: Missing docstring in magic method
4
```
Resolved the docstring errors as listed. I initially made the changes on the main branch of my forked repo, which caused them to appear in my PR for another issue. I have fixed that and hope this PR won't have any conflicts.
Kindly review @svekars @jbschlosser.
In case of any other issues please let me know. Thanks!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112964
Approved by: https://github.com/kit1980
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verify by compiling something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If the above condition is not met, instead of printing the static_assert message, VC++ raises very cryptic compilation errors; see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
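As a rough illustration of the grouping idea (a toy sketch, not the signature of the actual `_group_tensors_by_device_and_dtype` helper):
```python
from collections import defaultdict


def group_by_device_and_dtype(tensorlists):
    """Group parallel tensor lists by the (device, dtype) of the first list's tensors."""
    grouped = defaultdict(lambda: [[] for _ in tensorlists])
    for i, t in enumerate(tensorlists[0]):
        key = (t.device, t.dtype)
        for j, tlist in enumerate(tensorlists):
            grouped[key][j].append(tlist[i])
    return dict(grouped)
```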
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99