Commit Graph

28 Commits

Author SHA1 Message Date
ab647bd325 Add missing interfaces of torch.optim.swa_utils (#117036)
Add type hints for the function/class interfaces that appear in torch/optim/swa_utils.py but are missing in torch/optim/swa_utils.pyi.

- get_ema_multi_avg_fn
- get_swa_multi_avg_fn
- get_ema_avg_fn
- get_swa_avg_fn
- AveragedModel.__init__(multi_avg_fn)
- SWALR.get_lr

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117036
Approved by: https://github.com/janeyx99
2024-04-12 17:17:36 +00:00
4624afaa30 use reset_running_stats in swa_utils.update_bn (#103801)
the stat reset in `swa_utils.update_bn` already exists in `NormBase.reset_running_stats`, so use that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103801
Approved by: https://github.com/janeyx99
2023-06-23 01:17:13 +00:00
f73ff54f9a Use torch._foreach_lerp for SWA update (#103550)
Launch fewer kernels during a SWA update thanks to `torch._foreach_lerp`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103550
Approved by: https://github.com/janeyx99
2023-06-21 15:35:20 +00:00
6d2887cc06 Reland "Move tensor grouping to ATen" (#103912)
This is a reland of https://github.com/pytorch/pytorch/pull/100007 with a build fix for Windows debug builds.
`at::native::ParamsHash` only works on structs with standard layout, but `std::string` isn't one in Visual C++ debug builds, which one can easily verified by running something like:
```cpp
#define _DEBUG
#include <type_traits>
#include <string>
static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
```
If above conditon is not met, instead of printing a static_assert output, VC++ raises a very cryptic compilation errors,  see https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for more detail.

Also, using `std::hash` for string should result in a faster hash function.

(cherry picked from commit 74b7a6c75e698378882d30958908073407f97fb3)

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 5914771</samp>

This pull request introduces a new function `_group_tensors_by_device_and_dtype` that can group tensors by their device and dtype, and updates the `foreach` utilities and several optimizers to use this function. The goal is to improve the performance, readability, and compatibility of the code that handles tensors with different properties. The pull request also adds a test case and type annotations for the new function, and some error checks for the `fused` argument in Adam and AdamW.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103912
Approved by: https://github.com/janeyx99
2023-06-21 09:26:33 +00:00
0cb5bc3b04 Revert "Move tensor grouping to ATen (#100007)"
This reverts commit 74b7a6c75e698378882d30958908073407f97fb3.

Reverted https://github.com/pytorch/pytorch/pull/100007 on behalf of https://github.com/izaitsevfb due to Breaks internal builds, see D46629727 ([comment](https://github.com/pytorch/pytorch/pull/100007#issuecomment-1587861598))
2023-06-12 18:30:33 +00:00
900226f20a add multi swa support for custom device (#103297)
Fixes #ISSUE_NUMBER
add multi swa support for custom device
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103297
Approved by: https://github.com/janeyx99
2023-06-10 10:01:08 +00:00
74b7a6c75e Move tensor grouping to ATen (#100007)
rel: #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100007
Approved by: https://github.com/janeyx99
2023-06-09 15:44:46 +00:00
45bf3f6216 Optimized EMA implementation (#94820)
This PR proposes an optimized way to do Exponential Moving Average (EMA), which is faster than the current way using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.

This implementation is asynchronous, and is built as an optimizer wrapper so that the EMA weight update happens without any additional CPU/GPU sync, just after optimizer steps, and with limited code changes.

Example usage:
```
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
```

Here are some benchmarks (time per iteration) on various torchvision models:

|model|this PR iteration time                      |swa_utils.AveragedModel iteration time| iteration speedup                                      |
|-----|-----------------------------|-----------------------|---------------------------------------------|
|     |                             |                       |                                             |
|regnet_x_1_6gf|62.73                        |67.998                 |1.08                                         |
|regnet_x_3_2gf|101.75                       |109.422                |1.08                                         |
|regnet_x_400mf|25.13                        |32.005                 |1.27                                         |
|regnet_x_800mf|33.01                        |37.466                 |1.13                                         |
|regnet_x_8gf|128.13                       |134.868                |1.05                                         |
|regnet_y_16gf|252.91                       |261.292                |1.03                                         |
|regnet_y_1_6gf|72.14                        |84.22                  |1.17                                         |
|regnet_y_3_2gf|99.99                        |109.296                |1.09                                         |
|regnet_y_400mf|29.53                        |36.506                 |1.24                                         |
|regnet_y_800mf|37.82                        |43.634                 |1.15                                         |
|regnet_y_8gf|196.63                       |203.317                |1.03                                         |
|resnet101|128.80                       |137.434                |1.07                                         |
|resnet152|182.85                       |196.498                |1.07                                         |
|resnet18|29.06                        |29.975                 |1.03                                         |
|resnet34|50.73                        |53.443                 |1.05                                         |
|resnet50|76.88                        |80.602                 |1.05                                         |
|resnext101_32x8d|277.29                       |280.759                |1.01                                         |
|resnext101_64x4d|269.56                       |281.052                |1.04                                         |
|resnext50_32x4d|100.73                       |101.102                |1.00                                         |
|shufflenet_v2_x0_5|10.56                        |15.419                 |1.46                                         |
|shufflenet_v2_x1_0|13.11                        |18.525                 |1.41                                         |
|shufflenet_v2_x1_5|18.05                        |23.132                 |1.28                                         |
|shufflenet_v2_x2_0|25.04                        |30.008                 |1.20                                         |
|squeezenet1_1|14.26                        |14.325                 |1.00                                         |
|swin_b|264.52                       |274.613                |1.04                                         |
|swin_s|180.66                       |188.914                |1.05                                         |
|swin_t|108.62                       |112.632                |1.04                                         |
|swin_v2_s|220.29                       |231.153                |1.05                                         |
|swin_v2_t|127.27                       |133.586                |1.05                                         |
|vgg11|95.52                        |103.714                |1.09                                         |
|vgg11_bn|106.49                       |120.711                |1.13                                         |
|vgg13|132.94                       |147.063                |1.11                                         |
|vgg13_bn|149.73                       |165.256                |1.10                                         |
|vgg16|158.19                       |172.865                |1.09                                         |
|vgg16_bn|177.04                       |192.888                |1.09                                         |
|vgg19|184.76                       |194.194                |1.05                                         |
|vgg19_bn|203.30                       |213.334                |1.05                                         |
|vit_b_16|217.31                       |219.748                |1.01                                         |
|vit_b_32|69.47                        |75.692                 |1.09                                         |
|vit_l_32|223.20                       |258.487                |1.16                                         |
|wide_resnet101_2|267.38                       |279.836                |1.05                                         |
|wide_resnet50_2|145.06                       |154.918                |1.07                                         |

You can see that in all cases it is faster than using `AveragedModel`. In fact in many cases, adding EMA does not add any overhead since the computation is hidden behind the usual iteration flow.

This is a similar implementation to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).

If the team is interested in merging this, let me know and I'll add some documentation similar to `swa_utils` and tests.

Credits to @szmigacz for the implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
2023-04-26 18:02:11 +00:00
e8b0f504e2 Fix unpicklable object in AveragedModel (#95979)
Fixes #95376

Don't store the callable `avg_fn`, instead test if `avg_fn` is None and call
the default impl if it's not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95979
Approved by: https://github.com/janeyx99
2023-03-12 05:13:22 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
0a69c50a46 Publicly expose _LRScheduler to LRScheduler (#88503)
Fixes #61232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88503
Approved by: https://github.com/soulitzer
2022-11-07 21:15:10 +00:00
512a3a48e3 sync AveragedModel buffers when use_buffers=False (#84054)
Fixes #84053

As described in the issue, the AveragedModel will deep copy the model during initialization, which means that the buffers in the averaged model cannot be updated together with the model.

One solution is to make the buffers equal to the source model every time when calling `update_parameters`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84054
Approved by: https://github.com/samdow
2022-10-24 16:03:14 +00:00
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
bda04e9f5e Add __all__ for torch.optim and torch.nn.modules modules (#80237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237
Approved by: https://github.com/albanD
2022-06-24 21:34:10 +00:00
660d9ddef4 Fix SWALR doc string (#79836)
In `torch/optim/swa_utils.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79836
Approved by: https://github.com/albanD
2022-06-20 12:57:07 +00:00
313557a613 Add missing import (#72840)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72840

Reviewed By: H-Huang

Differential Revision: D34242612

Pulled By: albanD

fbshipit-source-id: 3dd34de96dbf1ae8f3c3ea45888d211d95862c49
(cherry picked from commit d2650ffa75dbba315daeb0e7cdf0fcb56f3584e1)
2022-02-15 19:43:54 +00:00
942a084c46 Remove state_dict from AveragedModel and use buffers instead (#71763)
Summary:
Fixes [https://github.com/pytorch/pytorch/issues/66686](https://github.com/pytorch/pytorch/issues/66686)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71763

Reviewed By: anjali411

Differential Revision: D33770907

Pulled By: prabhat00155

fbshipit-source-id: ee32f2cb8475c9add4e1a9a5d3d784ef95825efc
(cherry picked from commit a15898b072ae5234c76afa005ec492ed158c51aa)
2022-01-26 13:31:30 +00:00
a2e35e167b refactor: update f-string for swa.utils.py (#68718)
Summary:
_ Update some old-style formats to f-string, for whole and coherent consistency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68718

Reviewed By: jbschlosser

Differential Revision: D32593746

Pulled By: albanD

fbshipit-source-id: fcc17958f8af6a3260beca883bc1065f019dcf0e
2021-11-22 11:23:18 -08:00
c7748fc172 Added validation of mode parameter in AveragedModel (#65921)
Summary:
Discussion: https://github.com/pytorch/pytorch/pull/65495#issuecomment-930460469

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
2021-10-03 08:42:28 -07:00
2ea724b1fd Added option to update parameters using state_dict in AveragedModel (#65495)
Summary:
While implementing [EMA](https://github.com/pytorch/vision/pull/4381)(which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](https://github.com/pytorch/vision/pull/4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation.

Discussion: https://github.com/pytorch/vision/pull/4406#pullrequestreview-753734102

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
2021-09-28 03:34:49 -07:00
27135f86fd fix docstring default value of last_epoch for SWALR in torch/optim/… (#62799)
Summary:
…swa_utils

Fixes https://github.com/pytorch/pytorch/issues/62633

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62799

Reviewed By: zou3519

Differential Revision: D30131929

Pulled By: H-Huang

fbshipit-source-id: 741c077073bbe398492dff0761836acdbba7be78
2021-08-06 08:15:10 -07:00
8c798e0622 Forbid trailing whitespace (#53406)
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857

These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`

The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```

I looked over the auto-generated changes and didn't see anything that looked problematic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406

Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377

This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348

Reviewed By: walterddr, seemethere

Differential Revision: D26856620

Pulled By: samestep

fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
2021-03-05 17:22:55 -08:00
c871abecf5 Added torch.no_grad() to update_bn (#52654)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/52055

This fixes the **out of memory error** while using update_bn in **SWA**, by not allocating memory for backpropagation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52654

Reviewed By: malfet

Differential Revision: D26620077

Pulled By: albanD

fbshipit-source-id: 890b5a78ba9c1a148f3ab7c63472a73d8f6412a4
2021-02-25 11:35:38 -08:00
e6779d4357 [*.py] Rename "Arguments:" to "Args:" (#49736)
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.

```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
    printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args:      1095
Arguments: 0336
```

It is easy enough to extend my parsers to support both variants, however it looks like `Arguments:` is wrong anyway, as per:

  - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)

  - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)

  - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)

Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.

PS: For related PRs, see tensorflow/tensorflow/pull/45420

PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736

Reviewed By: albanD

Differential Revision: D25710534

Pulled By: soumith

fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00
09173ae65e Allow zero annealing epochs (#47579)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/47578.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47579

Reviewed By: H-Huang

Differential Revision: D25429403

Pulled By: vincentqb

fbshipit-source-id: c42fbcd71b46e07c672a1e9661468848ac16de38
2020-12-16 14:09:43 -08:00
9600ed9af3 typo fixes (#41632)
Summary:
typo fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41632

Reviewed By: ezyang

Differential Revision: D22617827

Pulled By: mrshenli

fbshipit-source-id: c2bfcb7cc36913a8dd32f13fc9adc3aa0a9b682f
2020-07-20 07:23:00 -07:00
22ac071d9a Add SWA to PyTorch mainline (#35032)
Summary:
This PR is based on the issue https://github.com/pytorch/pytorch/issues/29994#issue-524418771 and the discussion in the previous version of the PR https://github.com/pytorch/pytorch/pull/30559. Specifically, I followed the interface outlined in this [comment](https://github.com/pytorch/pytorch/pull/30559#issuecomment-574864768).

## Structure
- `torch/optim/swa_utils.py` contains the implementation of  `AveragedModel` class, `SWALR` learning rate scheduler and `update_bn` utility
- `test/test_optim.py` contains unit tests for the three components of SWA
- `torch/optim/swa_utils.pyi` describes the interface of `torch/optim/swa_utils.py`

The new implementation consists of
- `AveragedModel` class; this class creates a copy of a given model and allows to compute running averages of the parameters.
- `SWALR` learning rate scheduler; after a certain number of epochs switches to a constant learning rate; this scheduler is supposed to be chained with other schedulers.
- `update_bn` utility; updates the Batch Normalization activation statistics for a given model and dataloader; this utility is meant to be applied to `AveragedModel` instances.

For `update_bn` I simplified the implementation compared to the [original PR](https://github.com/pytorch/pytorch/pull/30559) according to the sugestions by vadimkantorov.

## Example
```python
loader, optimizer, model = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
# You can use custom averaging functions with `avg_fun` parameter
ema_avg = lambda p_avg, p, n_avg: 0.1 * p_avg + 0.9 * p
ema_model = torch.optim.swa_utils.AveragedModel(model,
                                    avg_function=ema_avg)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                    T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, start_epoch=swa_start, swa_lr=0.05)

for i in range(300):
     for input, target in loader:
         optimizer.zero_grad()
         loss_fn(model(input), target).backward()
         optimizer.step()
         scheduler.step()
         swa_scheduler.step()

     if i > swa_start:
         swa_model.update_parameters(model)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```

UPDATED:
```python3
loader, optimizer, model, loss_fn = ...
swa_model = torch.optim.swa_utils.AveragedModel(model)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
swa_start = 160
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
for i in range(300):
     for input, target in loader:
         optimizer.zero_grad()
         loss_fn(model(input), target).backward()
         optimizer.step()
     if i > swa_start:
         swa_model.update_parameters(model)
         swa_scheduler.step()
     else:
         scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
```

Fixes https://github.com/pytorch/pytorch/issues/29994
cc soumith vincentqb andrewgordonwilson vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35032

Differential Revision: D21079606

Pulled By: vincentqb

fbshipit-source-id: e07f5e821f72ada63789814c2dcbdc31f0160c37
2020-04-27 07:42:19 -07:00