1. Prevents unintended aliasing of `self._last_lr`/`get_last_lr(...)` with `group["lr"]` when `group["lr"]` is a tensor.
2. Prevents unintended aliasing of `LRScheduler.base_lrs` with the `group["initial_lr"]`s.
3. Updates `test/optim/test_lrscheduler.py` to test tensor LRs.
4. Changes type annotations for `_last_lr`, `get_last_lr()`, `base_lrs`, `get_lr()`, and `_get_closed_form_lr()` from `list[float]` to `list[float | Tensor]`; adds documentation.
Fixes#163103
LR schedulers can behave in unexpected ways when using a tensor LR due to patterns like this:
```python
self._last_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]
```
This PR adds a helper to address this:
```python
def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]:
"""Create a list containing group[key] for each optimizer param_group.
Prevents aliasing when group[key] could be a Tensor.
Raises a KeyError when group[key] does not exist.
"""
return [
group[key].clone() if isinstance(group[key], Tensor) else group[key]
for group in optimizer.param_groups
]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163120
Approved by: https://github.com/janeyx99
Fixes#163105
Note that the new `SWALR.load_state_dict` is **not backwards compatible**:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
self._set_anneal_func(self._anneal_strategy)
```
If we'd like to maintain compatibility with old state_dicts (loaded with `weights_only=False`), we could use something along these lines:
```python
@override
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the scheduler's state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
anneal_func = state_dict.pop("anneal_func", None)
strategy = state_dict.get("_anneal_strategy")
self.__dict__.update(state_dict)
if anneal_func is not None:
state_dict["anneal_func"] = anneal_func
if strategy is None:
if anneal_func == self._linear_anneal:
strategy = "linear"
elif anneal_func == self._cosine_anneal:
strategy = "cos"
if strategy is None:
strategy = getattr(self, "_anneal_strategy", "cos")
self._set_anneal_func(strategy)
```
But given the fact that loading an `SWALR` state_dict before this PR would have caused an error, this seems okay. A GitHub/Google search for `SWALR.load_state_dict` had no results. Happy to change if not, or add a warning just in case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163122
Approved by: https://github.com/janeyx99
Summary:
The test `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is called for each update.
This test is only meant to know whether the AveragedModel copy has been initialized or not.
This diff introduces a CPU-based variable for that purpose.
When loading from checkpoint we also make sure the parameter is refreshed.
After this fix, each `update_parameter` call is reduced to 6ms from 333ms (98% reduction).
Test Plan:
contbuild & OSS CI
Test plan from GitHub:
CI
Rollback Plan:
Differential Revision: D78074709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158017
Approved by: https://github.com/janeyx99
Over time, a large number of the existing type ignores have become irrelevant/unused/dead as a result of improvements in annotations and type checking.
Having these `# type: ignore` linger around is not ideal for two reasons:
- They are verbose/ugly syntatically.
- They could hide genuine bugs in the future, if a refactoring would actually introduce a bug but it gets hidden by the ignore.
I'm counting over 1500 unused ignores already. This is a first PR that removes some of them. Note that I haven't touched type ignores that looked "conditional" like the import challenge mentioned in https://github.com/pytorch/pytorch/pull/60006#issuecomment-2480604728. I will address these at a later point, and eventually would enable `warn_unused_ignores = True` in the mypy configuration as discussed in that comment to prevent accumulating more dead ignores going forward.
This PR should have no effect on runtime at all.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142325
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
According to the documentation, decay is a number in [0,1] range,[ i.e.](https://pytorch.org/docs/stable/optim.html)
```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```
An inspection of `swa_utils.py` indicates there are no checks for invalid values of `decay`. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).
Fixes https://github.com/pytorch/pytorch/issues/133772
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773
Approved by: https://github.com/janeyx99
Link various classes and functions of the `optim.swa.util` to make doc content accessible from the `torch.optim` doc.
Currently, if you click the link,
https://pytorch.org/docs/stable/optim.html#module-torch.optim.swa_utils it goes to a blank, bottom of the page section of `torch.optim`.
Also,
`torch.optim.swa_utils.AveragedModel` and `torch.optim.swa_utils.SWALR` classes as well as `torch.optim.swa_utils.update_bn()` and `optim.swa_utils.get_ema_multi_avg_fn` are not linked to doc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133393
Approved by: https://github.com/janeyx99
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors, see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.
Also, using `std::hash` for string should result in a faster hash function.
(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>
This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
This PR proposes an optimized way to do Exponential Moving Average (EMA), which is faster than the current way using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.
This implementation is asynchronous, and is built as an optimizer wrapper so that the EMA weight update happens without any additional CPU/GPU sync, just after optimizer steps, and with limited code changes.
Example usage:
```
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)
for epoch in range(epochs):
training_loop(model, opt)
regular_eval_accuracy = evaluate(model)
with opt.swap_ema_weights():
ema_eval_accuracy = evaluate(model)
```
Here are some benchmarks (time per iteration) on various torchvision models:
|model|this PR iteration time |swa_utils.AveragedModel iteration time| iteration speedup |
|-----|-----------------------------|-----------------------|---------------------------------------------|
| | | | |
|regnet_x_1_6gf|62.73 |67.998 |1.08 |
|regnet_x_3_2gf|101.75 |109.422 |1.08 |
|regnet_x_400mf|25.13 |32.005 |1.27 |
|regnet_x_800mf|33.01 |37.466 |1.13 |
|regnet_x_8gf|128.13 |134.868 |1.05 |
|regnet_y_16gf|252.91 |261.292 |1.03 |
|regnet_y_1_6gf|72.14 |84.22 |1.17 |
|regnet_y_3_2gf|99.99 |109.296 |1.09 |
|regnet_y_400mf|29.53 |36.506 |1.24 |
|regnet_y_800mf|37.82 |43.634 |1.15 |
|regnet_y_8gf|196.63 |203.317 |1.03 |
|resnet101|128.80 |137.434 |1.07 |
|resnet152|182.85 |196.498 |1.07 |
|resnet18|29.06 |29.975 |1.03 |
|resnet34|50.73 |53.443 |1.05 |
|resnet50|76.88 |80.602 |1.05 |
|resnext101_32x8d|277.29 |280.759 |1.01 |
|resnext101_64x4d|269.56 |281.052 |1.04 |
|resnext50_32x4d|100.73 |101.102 |1.00 |
|shufflenet_v2_x0_5|10.56 |15.419 |1.46 |
|shufflenet_v2_x1_0|13.11 |18.525 |1.41 |
|shufflenet_v2_x1_5|18.05 |23.132 |1.28 |
|shufflenet_v2_x2_0|25.04 |30.008 |1.20 |
|squeezenet1_1|14.26 |14.325 |1.00 |
|swin_b|264.52 |274.613 |1.04 |
|swin_s|180.66 |188.914 |1.05 |
|swin_t|108.62 |112.632 |1.04 |
|swin_v2_s|220.29 |231.153 |1.05 |
|swin_v2_t|127.27 |133.586 |1.05 |
|vgg11|95.52 |103.714 |1.09 |
|vgg11_bn|106.49 |120.711 |1.13 |
|vgg13|132.94 |147.063 |1.11 |
|vgg13_bn|149.73 |165.256 |1.10 |
|vgg16|158.19 |172.865 |1.09 |
|vgg16_bn|177.04 |192.888 |1.09 |
|vgg19|184.76 |194.194 |1.05 |
|vgg19_bn|203.30 |213.334 |1.05 |
|vit_b_16|217.31 |219.748 |1.01 |
|vit_b_32|69.47 |75.692 |1.09 |
|vit_l_32|223.20 |258.487 |1.16 |
|wide_resnet101_2|267.38 |279.836 |1.05 |
|wide_resnet50_2|145.06 |154.918 |1.07 |
You can see that in all cases it is faster than using `AveragedModel`. In fact in many cases, adding EMA does not add any overhead since the computation is hidden behind the usual iteration flow.
This is a similar implementation to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
If the team is interested in merging this, let me know and I'll add some documentation similar to `swa_utils` and tests.
Credits to @szmigacz for the implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
Fixes#84053
As described in the issue, the AveragedModel will deep copy the model during initialization, which means that the buffers in the averaged model cannot be updated together with the model.
One solution is to make the buffers equal to the source model every time when calling `update_parameters`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84054
Approved by: https://github.com/samdow
This is a new version of #15648 based on the latest master branch.
Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.
In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
Fixes https://github.com/pytorch/pytorch/issues/71105
@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857
These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
- `GLOSSARY.md`
- `aten/src/ATen/core/op_registration/README.md`
- `scripts/README.md`
- `torch/csrc/jit/codegen/fuser/README.md`
The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```
I looked over the auto-generated changes and didn't see anything that looked problematic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406
Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377
This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348
Reviewed By: walterddr, seemethere
Differential Revision: D26856620
Pulled By: samestep
fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
Summary:
Fixes https://github.com/pytorch/pytorch/issues/52055
This fixes the **out of memory error** while using update_bn in **SWA**, by not allocating memory for backpropagation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52654
Reviewed By: malfet
Differential Revision: D26620077
Pulled By: albanD
fbshipit-source-id: 890b5a78ba9c1a148f3ab7c63472a73d8f6412a4