Optimized EMA implementation (#94820)

This PR proposes an optimized implementation of Exponential Moving Average (EMA) that is faster than the current approach based on `swa_utils.AveragedModel`, described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies.
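
For comparison, the `AveragedModel` baseline from those docs keeps a deep copy of the model and updates it inline after every optimizer step. A minimal sketch, assuming `model`, `loader`, `optimizer`, and `loss_fn` already exist:

```python
from torch.optim.swa_utils import AveragedModel

decay = 0.9999
# Custom averaging function implementing EMA, as in the linked docs section.
ema_avg = lambda avg_p, p, num_averaged: decay * avg_p + (1 - decay) * p
ema_model = AveragedModel(model, avg_fn=ema_avg)

for input, target in loader:
    optimizer.zero_grad()
    loss_fn(model(input), target).backward()
    optimizer.step()
    ema_model.update_parameters(model)  # EMA update issued inline, parameter by parameter
```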

The implementation is asynchronous and is built as an optimizer wrapper: the EMA weight update runs right after each optimizer step, without any additional CPU/GPU synchronization and with minimal code changes.

Example usage:
```python
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

# Wrap the base optimizer; 0.9999 is the EMA decay, and the EMA weights
# are updated automatically right after each optimizer step.
opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    # Temporarily swap the EMA weights into the model for evaluation.
    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
```
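
This is not code from the PR, only a rough sketch of the general idea behind "no additional CPU/GPU sync": the EMA update can be enqueued on a side CUDA stream using fused `torch._foreach_*` ops, so it overlaps with the rest of the iteration. The names `launch_ema_update` and `ema_params` are placeholders, not part of the proposed API:

```python
import torch

@torch.no_grad()
def launch_ema_update(ema_params, model_params, decay, stream):
    # Make the side stream wait for the work already queued on the main
    # stream (e.g. the optimizer step), then enqueue the fused EMA update.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        # ema = decay * ema + (1 - decay) * param, applied to all tensors at once
        torch._foreach_lerp_(ema_params, model_params, 1.0 - decay)
    # A real wrapper must also make the main stream wait on `stream` before
    # the EMA weights are read or the parameters are modified again.

# Usage sketch (assumes all tensors already live on the GPU):
# ema_stream = torch.cuda.Stream()
# launch_ema_update(ema_params, list(model.parameters()), 0.9999, ema_stream)
```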

Here are some benchmarks (time per iteration) on various torchvision models:

|model|this PR iteration time|swa_utils.AveragedModel iteration time|iteration speedup|
|-----|----------------------|--------------------------------------|-----------------|
|regnet_x_1_6gf|62.73                        |67.998                 |1.08                                         |
|regnet_x_3_2gf|101.75                       |109.422                |1.08                                         |
|regnet_x_400mf|25.13                        |32.005                 |1.27                                         |
|regnet_x_800mf|33.01                        |37.466                 |1.13                                         |
|regnet_x_8gf|128.13                       |134.868                |1.05                                         |
|regnet_y_16gf|252.91                       |261.292                |1.03                                         |
|regnet_y_1_6gf|72.14                        |84.22                  |1.17                                         |
|regnet_y_3_2gf|99.99                        |109.296                |1.09                                         |
|regnet_y_400mf|29.53                        |36.506                 |1.24                                         |
|regnet_y_800mf|37.82                        |43.634                 |1.15                                         |
|regnet_y_8gf|196.63                       |203.317                |1.03                                         |
|resnet101|128.80                       |137.434                |1.07                                         |
|resnet152|182.85                       |196.498                |1.07                                         |
|resnet18|29.06                        |29.975                 |1.03                                         |
|resnet34|50.73                        |53.443                 |1.05                                         |
|resnet50|76.88                        |80.602                 |1.05                                         |
|resnext101_32x8d|277.29                       |280.759                |1.01                                         |
|resnext101_64x4d|269.56                       |281.052                |1.04                                         |
|resnext50_32x4d|100.73                       |101.102                |1.00                                         |
|shufflenet_v2_x0_5|10.56                        |15.419                 |1.46                                         |
|shufflenet_v2_x1_0|13.11                        |18.525                 |1.41                                         |
|shufflenet_v2_x1_5|18.05                        |23.132                 |1.28                                         |
|shufflenet_v2_x2_0|25.04                        |30.008                 |1.20                                         |
|squeezenet1_1|14.26                        |14.325                 |1.00                                         |
|swin_b|264.52                       |274.613                |1.04                                         |
|swin_s|180.66                       |188.914                |1.05                                         |
|swin_t|108.62                       |112.632                |1.04                                         |
|swin_v2_s|220.29                       |231.153                |1.05                                         |
|swin_v2_t|127.27                       |133.586                |1.05                                         |
|vgg11|95.52                        |103.714                |1.09                                         |
|vgg11_bn|106.49                       |120.711                |1.13                                         |
|vgg13|132.94                       |147.063                |1.11                                         |
|vgg13_bn|149.73                       |165.256                |1.10                                         |
|vgg16|158.19                       |172.865                |1.09                                         |
|vgg16_bn|177.04                       |192.888                |1.09                                         |
|vgg19|184.76                       |194.194                |1.05                                         |
|vgg19_bn|203.30                       |213.334                |1.05                                         |
|vit_b_16|217.31                       |219.748                |1.01                                         |
|vit_b_32|69.47                        |75.692                 |1.09                                         |
|vit_l_32|223.20                       |258.487                |1.16                                         |
|wide_resnet101_2|267.38                       |279.836                |1.05                                         |
|wide_resnet50_2|145.06                       |154.918                |1.07                                         |

In all cases this implementation is faster than using `AveragedModel`. In fact, in many cases adding EMA introduces no measurable overhead, since the update is hidden behind the rest of the iteration.
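
The benchmark script is not part of this description; a per-iteration timing like the one in the table above could be collected with something along these lines (a sketch, with the hypothetical `step_fn` standing in for one full training iteration, including the EMA update):

```python
import time
import torch

def time_per_iteration(step_fn, iters=100, warmup=10):
    # step_fn() runs one full iteration: forward, backward, optimizer step,
    # and the EMA update being measured.
    for _ in range(warmup):
        step_fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    torch.cuda.synchronize()  # include all queued GPU work in the measurement
    return (time.perf_counter() - start) / iters
```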

The implementation is similar to the one currently used in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).

If the team is interested in merging this, let me know and I'll add documentation (similar to the `swa_utils` docs) and tests.

Credits to @szmigacz for the implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
Commit: 45bf3f6216 (parent: c6ab4ff35c)
Author: milesial
Date: 2023-04-26 18:02:11 +00:00
Committed by: PyTorch MergeBot
3 changed files with 244 additions and 129 deletions

@@ -254,32 +254,73 @@ algorithms.
lr_scheduler.OneCycleLR
lr_scheduler.CosineAnnealingWarmRestarts
Stochastic Weight Averaging
---------------------------
Weight Averaging (SWA and EMA)
------------------------------
:mod:`torch.optim.swa_utils` implements Stochastic Weight Averaging (SWA). In particular,
:class:`torch.optim.swa_utils.AveragedModel` class implements SWA models,
:mod:`torch.optim.swa_utils` implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). In particular,
the :class:`torch.optim.swa_utils.AveragedModel` class implements SWA and EMA models,
:class:`torch.optim.swa_utils.SWALR` implements the SWA learning rate scheduler and
:func:`torch.optim.swa_utils.update_bn` is a utility function used to update SWA batch
:func:`torch.optim.swa_utils.update_bn` is a utility function used to update SWA/EMA batch
normalization statistics at the end of training.
SWA has been proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_.
EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. It is a variation of `Polyak averaging`_, but using exponential weights instead of equal weights across iterations.
.. _`Averaging Weights Leads to Wider Optima and Better Generalization`: https://arxiv.org/abs/1803.05407
.. _`Polyak averaging`: https://paperswithcode.com/method/polyak-averaging
Constructing averaged models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
`AveragedModel` class serves to compute the weights of the SWA model. You can create an
averaged model by running:
The `AveragedModel` class serves to compute the weights of the SWA or EMA model.
>>> swa_model = AveragedModel(model)
You can create an SWA averaged model by running:
Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``swa_model``
>>> averaged_model = AveragedModel(model)
EMA models are constructed by specifying the ``multi_avg_fn`` argument as follows:
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to ``get_ema_multi_avg_fn``, the default is 0.999.
``get_ema_multi_avg_fn`` returns a function that applies the following EMA equation to the weights:
.. math:: W^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t
where alpha is the EMA decay.
Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``averaged_model``
will keep track of the running averages of the parameters of the ``model``. To update these
averages, you can use the :func:`update_parameters` function:
averages, you should use the :func:`update_parameters` function after the `optimizer.step()`:
>>> swa_model.update_parameters(model)
>>> averaged_model.update_parameters(model)
For SWA and EMA, this call is usually done right after the optimizer ``step()``. In the case of SWA, this is usually skipped for some numbers of steps at the beginning of the training.
Custom averaging strategies
^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of
the parameters that you provide, but you can also use custom averaging functions with the
``avg_fn`` or ``multi_avg_fn`` parameters:
- ``avg_fn`` allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter.
- ``multi_avg_fn`` allows defining more efficient operations acting on a tuple of parameter lists, (averaged parameter list, model parameter list), at the same time, for example using the ``torch._foreach*`` functions. This function must update the averaged parameters in-place.
In the following example ``ema_model`` computes an exponential moving average using the ``avg_fn`` parameter:
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>> 0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
In the following example ``ema_model`` computes an exponential moving average using the more efficient ``multi_avg_fn`` parameter:
>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
SWA learning rate schedules
@@ -315,22 +356,10 @@ statistics for each batch normalization layer in the model.
``swa_model`` by doing a forward pass with the ``swa_model`` on each element of the dataset.
Custom averaging strategies
^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of
the parameters that you provide, but you can also use custom averaging functions with the
``avg_fn`` parameter. In the following example ``ema_model`` computes an exponential moving average.
Example:
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>> 0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^
Putting it all together: SWA
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the example below, ``swa_model`` is the SWA model that accumulates the averages of the weights.
We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule
@@ -357,3 +386,26 @@ and start to collect SWA averages of the parameters at epoch 160:
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
Putting it all together: EMA
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In the example below, ``ema_model`` is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999.
We train the model for a total of 300 epochs and start to collect EMA averages immediately.
>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>> multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>> for input, target in loader:
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
>>> ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

@@ -33,7 +33,7 @@ from torch.optim.lr_scheduler import (
PolynomialLR,
EPOCH_DEPRECATION_WARNING,
)
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn, get_swa_multi_avg_fn, get_ema_multi_avg_fn
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
@@ -3997,7 +3997,7 @@ class SWATestCNN(torch.nn.Module):
class TestSWAUtils(TestCase):
def _test_averaged_model(self, net_device, swa_device):
def _test_averaged_model(self, net_device, swa_device, ema):
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.ReLU(),
@@ -4010,32 +4010,48 @@ class TestSWAUtils(TestCase):
torch.nn.Linear(5, 10),
).to(net_device)
averaged_dnn = AveragedModel(dnn, device=swa_device)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)
averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
# Check that AveragedModel is on the correct device
self.assertTrue(p_swa.device == swa_device)
self.assertTrue(p.device == net_device)
self.assertTrue(p_avg.device == net_device)
self.assertTrue(averaged_dnn.n_averaged.device == swa_device)
def test_averaged_model_all_devices(self):
def _run_averaged_steps(self, dnn, swa_device, ema):
ema_decay = 0.999
if ema:
averaged_dnn = AveragedModel(dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay))
else:
averaged_dnn = AveragedModel(dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn())
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if ema:
p_avg += p.detach() * ema_decay ** (n_updates - i - 1) * ((1 - ema_decay) if i > 0 else 1.0)
else:
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)
return averaged_params, averaged_dnn
@parametrize("ema", [True, False])
def test_averaged_model_all_devices(self, ema):
cpu = torch.device("cpu")
self._test_averaged_model(cpu, cpu)
self._test_averaged_model(cpu, cpu, ema)
if torch.cuda.is_available():
cuda = torch.device(0)
self._test_averaged_model(cuda, cpu)
self._test_averaged_model(cpu, cuda)
self._test_averaged_model(cuda, cuda)
self._test_averaged_model(cuda, cpu, ema)
self._test_averaged_model(cpu, cuda, ema)
self._test_averaged_model(cuda, cuda, ema)
def test_averaged_model_mixed_device(self):
@parametrize("ema", [True, False])
def test_averaged_model_mixed_device(self, ema):
if not torch.cuda.is_available():
return
dnn = torch.nn.Sequential(
@@ -4043,14 +4059,8 @@ class TestSWAUtils(TestCase):
)
dnn[0].cuda()
dnn[1].cpu()
averaged_dnn = AveragedModel(dnn)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
p_avg += p.detach() / n_updates
averaged_dnn.update_parameters(dnn)
averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
@@ -4082,62 +4092,36 @@ class TestSWAUtils(TestCase):
averaged_dnn = AveragedModel(dnn)
pickle.dumps(averaged_dnn)
def test_averaged_model_exponential(self):
# Test AveragedModel with EMA as avg_fn
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3),
torch.nn.Linear(5, 10),
)
alpha = 0.9
def avg_fn(p_avg, p, n_avg):
return alpha * p_avg + (1 - alpha) * p
averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)
averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
for p, p_avg in zip(dnn.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))
if i == 0:
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append(
(p_avg * alpha + p * (1 - alpha)).clone()
)
for b in dnn.buffers():
if b.size() != torch.Size([]):
b.detach_().add_(torch.randn_like(b))
averaged_dnn.update_parameters(dnn)
averaged_params = updated_averaged_params
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
self.assertEqual(b_avg, b_swa)
def test_averaged_model_exponential_buffers(self):
@parametrize("use_multi_avg_fn", [True, False])
@parametrize("use_buffers", [True, False])
def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
# Test AveragedModel with EMA as avg_fn and use_buffers as True.
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3),
torch.nn.BatchNorm2d(5, momentum=0.3),
torch.nn.Linear(5, 10),
)
alpha = 0.9
decay = 0.9
def avg_fn(p_avg, p, n_avg):
return alpha * p_avg + (1 - alpha) * p
if use_multi_avg_fn:
averaged_dnn = AveragedModel(dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers)
else:
def avg_fn(p_avg, p, n_avg):
return decay * p_avg + (1 - decay) * p
averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)
if use_buffers:
dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
else:
dnn_params = list(dnn.parameters())
averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
averaged_params = [
torch.zeros_like(param)
for param in dnn_params
if param.size() != torch.Size([])
]
n_updates = 10
for i in range(n_updates):
updated_averaged_params = []
@@ -4149,18 +4133,24 @@ class TestSWAUtils(TestCase):
updated_averaged_params.append(p.clone())
else:
updated_averaged_params.append(
(p_avg * alpha + p * (1 - alpha)).clone()
(p_avg * decay + p * (1 - decay)).clone()
)
averaged_dnn.update_parameters(dnn)
averaged_params = updated_averaged_params
for p_avg, p_swa in zip(
averaged_params,
itertools.chain(
averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
),
):
self.assertEqual(p_avg, p_swa)
if use_buffers:
for p_avg, p_swa in zip(
averaged_params,
itertools.chain(
averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
),
):
self.assertEqual(p_avg, p_swa)
else:
for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
self.assertEqual(p_avg, p_swa)
for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
self.assertEqual(b_avg, b_swa)
def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
@@ -4265,6 +4255,7 @@ class TestSWAUtils(TestCase):
instantiate_parametrized_tests(TestLRScheduler)
instantiate_parametrized_tests(TestSWAUtils)
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):

@@ -7,30 +7,87 @@ import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
__all__ = ['AveragedModel', 'update_bn', 'SWALR']
__all__ = [
'AveragedModel',
'update_bn',
'SWALR',
'get_ema_multi_avg_fn',
'get_swa_multi_avg_fn',
'get_ema_avg_fn',
'get_swa_avg_fn'
]
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
def get_ema_multi_avg_fn(decay=0.999):
@torch.no_grad()
def ema_update(ema_param_list, current_param_list, _):
# foreach lerp only handles float and complex
if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
else:
for p_ema, p_model in zip(ema_param_list, current_param_list):
p_ema.copy_(p_ema * decay + p_model * (1 - decay))
return ema_update
def get_swa_multi_avg_fn():
@torch.no_grad()
def swa_update(averaged_param_list, current_param_list, num_averaged):
diffs = torch._foreach_sub(current_param_list, averaged_param_list)
torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))
return swa_update
def get_ema_avg_fn(decay=0.999):
@torch.no_grad()
def ema_update(ema_param, current_param, num_averaged):
return decay * ema_param + (1 - decay) * current_param
return ema_update
def get_swa_avg_fn():
@torch.no_grad()
def swa_update(averaged_param, current_param, num_averaged):
return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
return swa_update
class AveragedModel(Module):
r"""Implements averaged model for Stochastic Weight Averaging (SWA).
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
Exponential Moving Average (EMA).
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).
Exponential Moving Average is a variation of `Polyak averaging`_,
but using exponential weights instead of equal weights across iterations.
AveragedModel class creates a copy of the provided module :attr:`model`
on the device :attr:`device` and allows to compute running averages of the
parameters of the :attr:`model`.
Args:
model (torch.nn.Module): model to use with SWA
model (torch.nn.Module): model to use with SWA/EMA
device (torch.device, optional): if provided, the averaged model will be
stored on the :attr:`device`
avg_fn (function, optional): the averaging function used to update
parameters; the function must take in the current value of the
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)
parameter, and the number of models already averaged; if None,
an equally weighted average is used (default: None)
multi_avg_fn (function, optional): the averaging function used to update
parameters inplace; the function must take in the current values of the
:class:`AveragedModel` parameters as a list, the current values of :attr:`model`
parameters as a list, and the number of models already averaged; if None,
an equally weighted average is used (default: None)
use_buffers (bool): if ``True``, it will compute running averages for
both the parameters and the buffers of the model. (default: ``False``)
@@ -56,19 +113,18 @@ class AveragedModel(Module):
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
You can also use custom averaging functions with `avg_fn` parameter.
You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
If no averaging function is provided, the default is to compute
equally-weighted average of the weights.
equally-weighted average of the weights (SWA).
Example:
>>> # xdoctest: +SKIP("undefined variables")
>>> # Compute exponential moving averages of the weights and buffers
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: (
... 0.1 * averaged_model_parameter + 0.9 * model_parameter)
>>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True)
>>> ema_model = torch.optim.swa_utils.AveragedModel(model,
>>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
.. note::
When using SWA with models containing Batch Normalization you may
When using SWA/EMA with models containing Batch Normalization you may
need to update the activation statistics for Batch Normalization.
This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
or by setting :attr:`use_buffers` to `True`. The first approach updates the
@@ -79,7 +135,7 @@ class AveragedModel(Module):
approach yields the best results in your problem.
.. note::
:attr:`avg_fn` is not saved in the :meth:`state_dict` of the model.
:attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.
.. note::
When :meth:`update_parameters` is called for the first time (i.e.
@@ -98,15 +154,19 @@ class AveragedModel(Module):
.. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
Generalizes Well:
https://arxiv.org/abs/2001.02312
.. _Polyak averaging:
https://paperswithcode.com/method/polyak-averaging
"""
def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
super().__init__()
assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
self.register_buffer('n_averaged',
torch.tensor(0, dtype=torch.long, device=device))
self.avg_fn = avg_fn
self.multi_avg_fn = multi_avg_fn
self.use_buffers = use_buffers
def forward(self, *args, **kwargs):
@@ -121,28 +181,40 @@ class AveragedModel(Module):
itertools.chain(model.parameters(), model.buffers())
if self.use_buffers else model.parameters()
)
for p_swa, p_model in zip(self_param, model_param):
device = p_swa.device
self_param_detached = []
model_param_detached = []
for p_averaged, p_model in zip(self_param, model_param):
device = p_averaged.device
p_model_ = p_model.detach().to(device)
self_param_detached.append(p_averaged.detach())
model_param_detached.append(p_model_)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
p_averaged.detach().copy_(p_model_)
if self.n_averaged > 0:
if self.multi_avg_fn is not None or self.avg_fn is None:
grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
for ((device, _), [self_params, model_params]) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
elif device.type == 'cuda':
multi_avg_fn = get_swa_multi_avg_fn()
multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
else:
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params):
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
if self.avg_fn is None:
p_swa.detach().copy_(
p_swa.detach()
+ (p_model_ - p_swa.detach()) / (self.n_averaged.to(device) + 1)
)
else:
p_swa.detach().copy_(
self.avg_fn(
p_swa.detach(), p_model_, self.n_averaged.to(device)
)
)
for p_averaged, p_model in zip(self_param_detached, model_param_detached):
n_averaged = self.n_averaged.to(p_averaged.device)
p_averaged.detach().copy_(self.avg_fn(p_averaged.detach(), p_model, n_averaged))
if not self.use_buffers:
# If not apply running averages to the buffers,
# keep the buffers in sync with the source model.
for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
b_swa.detach().copy_(b_model.detach().to(device))
b_swa.detach().copy_(b_model.detach().to(b_swa.device))
self.n_averaged += 1