This PR proposes an optimized way to do Exponential Moving Average (EMA), which is faster than the current way using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.
This implementation is asynchronous, and is built as an optimizer wrapper so that the EMA weight update happens without any additional CPU/GPU sync, just after optimizer steps, and with limited code changes.
Example usage:
```
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)
for epoch in range(epochs):
training_loop(model, opt)
regular_eval_accuracy = evaluate(model)
with opt.swap_ema_weights():
ema_eval_accuracy = evaluate(model)
```
Here are some benchmarks (time per iteration) on various torchvision models:
|model|this PR iteration time |swa_utils.AveragedModel iteration time| iteration speedup |
|-----|-----------------------------|-----------------------|---------------------------------------------|
| | | | |
|regnet_x_1_6gf|62.73 |67.998 |1.08 |
|regnet_x_3_2gf|101.75 |109.422 |1.08 |
|regnet_x_400mf|25.13 |32.005 |1.27 |
|regnet_x_800mf|33.01 |37.466 |1.13 |
|regnet_x_8gf|128.13 |134.868 |1.05 |
|regnet_y_16gf|252.91 |261.292 |1.03 |
|regnet_y_1_6gf|72.14 |84.22 |1.17 |
|regnet_y_3_2gf|99.99 |109.296 |1.09 |
|regnet_y_400mf|29.53 |36.506 |1.24 |
|regnet_y_800mf|37.82 |43.634 |1.15 |
|regnet_y_8gf|196.63 |203.317 |1.03 |
|resnet101|128.80 |137.434 |1.07 |
|resnet152|182.85 |196.498 |1.07 |
|resnet18|29.06 |29.975 |1.03 |
|resnet34|50.73 |53.443 |1.05 |
|resnet50|76.88 |80.602 |1.05 |
|resnext101_32x8d|277.29 |280.759 |1.01 |
|resnext101_64x4d|269.56 |281.052 |1.04 |
|resnext50_32x4d|100.73 |101.102 |1.00 |
|shufflenet_v2_x0_5|10.56 |15.419 |1.46 |
|shufflenet_v2_x1_0|13.11 |18.525 |1.41 |
|shufflenet_v2_x1_5|18.05 |23.132 |1.28 |
|shufflenet_v2_x2_0|25.04 |30.008 |1.20 |
|squeezenet1_1|14.26 |14.325 |1.00 |
|swin_b|264.52 |274.613 |1.04 |
|swin_s|180.66 |188.914 |1.05 |
|swin_t|108.62 |112.632 |1.04 |
|swin_v2_s|220.29 |231.153 |1.05 |
|swin_v2_t|127.27 |133.586 |1.05 |
|vgg11|95.52 |103.714 |1.09 |
|vgg11_bn|106.49 |120.711 |1.13 |
|vgg13|132.94 |147.063 |1.11 |
|vgg13_bn|149.73 |165.256 |1.10 |
|vgg16|158.19 |172.865 |1.09 |
|vgg16_bn|177.04 |192.888 |1.09 |
|vgg19|184.76 |194.194 |1.05 |
|vgg19_bn|203.30 |213.334 |1.05 |
|vit_b_16|217.31 |219.748 |1.01 |
|vit_b_32|69.47 |75.692 |1.09 |
|vit_l_32|223.20 |258.487 |1.16 |
|wide_resnet101_2|267.38 |279.836 |1.05 |
|wide_resnet50_2|145.06 |154.918 |1.07 |
You can see that in all cases it is faster than using `AveragedModel`. In fact in many cases, adding EMA does not add any overhead since the computation is hidden behind the usual iteration flow.
This is a similar implementation to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
If the team is interested in merging this, let me know and I'll add some documentation similar to `swa_utils` and tests.
Credits to @szmigacz for the implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
Fixes#95781.
The cause seems to be that the current implementation doesn't correctly pass `found_inf` when `grad_scale` is `None`. Therefore parameters can get mistakenly updated by gradients whose some elements are invalid, i.e. nan or inf.
Related #94060
I forgot about this wrong handling after #94344
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95847
Approved by: https://github.com/janeyx99
Big OOP correction continued. Also added a test this time to verify the defaulting was as expected.
The key here is realizing that the grouping for foreach already assumes that the non-param tensorlists follow suit in dtype and device, so it is too narrow to check that _all_ tensors were on CUDA. The main leeway this allowed was state_steps, which are sometimes cpu tensors. Since foreach _can_ handle cpu tensors, this should not introduce breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95820
Approved by: https://github.com/albanD
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Old behavior would have adadelta foreach sending tensors to the slow path if they were not all the same dtype nor on the same device.
This PR adds grouping for adadelta optimizer so that it would run foreach in batches, allowing more users to benefit from foreach perf.
Of course, we should ensure that the new implementation works, so there are new tests to ensure this behavior is not broken.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92048
Approved by: https://github.com/albanD
I realized test_fused_optimizers used a helper that was written for foreach, so we were not testing fused at all. This PR fixes that test so we actually test fused adam.
The explicitly adding fused=False is to set the stage for my later changes (but should be a no-op here).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91228
Approved by: https://github.com/albanD, https://github.com/soulitzer
Fixes#84053
As described in the issue, the AveragedModel will deep copy the model during initialization, which means that the buffers in the averaged model cannot be updated together with the model.
One solution is to make the buffers equal to the source model every time when calling `update_parameters`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84054
Approved by: https://github.com/samdow
Hi, we noticed in our team that by using CyclicLR, there is a problem with memory clearance on GPU (probably it will be the case without the GPU as well, but that was our use case) After initializing CyclicLR, GPU memory is not cleared even after the model, optimizer and scheduler are out of scope (e.g. reference count is zero). This is because `__init__` method inside `CyclicLR` creates reference to its own methods and it will not get removed until `gc.collect()` is called manually. This is a problem if people want to test multiple models in one run of a script, after testing the first model, second one will fail on `CUDA out of memory error` because the first one is not cleared from the memory.
I propose a simple fix by using `weakref`, similarly as in `_LRScheduler` base class, but if you have any comments I am happy to change it.
Here is the code to reproduce the bug:
```
import torch
import weakref
from transformers import DetrForObjectDetection
class X:
def __init__(self, optimizer):
self.optimizer = optimizer
# Will cause cyclic reference.
self.func = self.dummy
# Will work as expected, memory cleared after instance count is zero.
# self.func = weakref.WeakMethod(self.dummy)
def dummy(self, x):
return 1.
def test():
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
model.to('cuda')
optimizer = torch.optim.Adam(model.parameters())
x = X(optimizer)
test()
print(f'{torch.cuda.memory_reserved()}, {torch.cuda.memory_allocated()}') # Should print (<some memory>, 0), but with cyclic reference, it will print (<some memory>, <some memory>).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85462
Approved by: https://github.com/albanD
This is to improve the performance for hybrid sparse coo tensor on CPU path. This case is appeared at the DLRM terabyte test.
With this fix, according to the previous performance test data, it got ~10x performance improvement on DLRM execution.
without this, the DLRM will run as
Finished training it 100/1000 of epoch 0, 2969.25 ms/it, loss 0.220505, accuracy 0.000 %
with this, the DLRM will run as
Finished training it 100/1000 of epoch 0, 270.71 ms/it, loss 0.220505, accuracy 0.000 %
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23057
Approved by: https://github.com/VitalyFedyunin, https://github.com/malfet