Optimized EMA implementation (#94820)
This PR proposes an optimized way to do Exponential Moving Average (EMA), which is faster than the current approach using `swa_utils.AveragedModel` described in https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies. This implementation is asynchronous and is built as an optimizer wrapper, so the EMA weight update happens right after the optimizer step, without any additional CPU/GPU sync and with limited code changes. Example usage:

```
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)
    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
```

Here are some benchmarks (time per iteration) on various torchvision models:

|model|this PR iteration time|swa_utils.AveragedModel iteration time|iteration speedup|
|-----|----------------------|--------------------------------------|-----------------|
|regnet_x_1_6gf|62.73|67.998|1.08|
|regnet_x_3_2gf|101.75|109.422|1.08|
|regnet_x_400mf|25.13|32.005|1.27|
|regnet_x_800mf|33.01|37.466|1.13|
|regnet_x_8gf|128.13|134.868|1.05|
|regnet_y_16gf|252.91|261.292|1.03|
|regnet_y_1_6gf|72.14|84.22|1.17|
|regnet_y_3_2gf|99.99|109.296|1.09|
|regnet_y_400mf|29.53|36.506|1.24|
|regnet_y_800mf|37.82|43.634|1.15|
|regnet_y_8gf|196.63|203.317|1.03|
|resnet101|128.80|137.434|1.07|
|resnet152|182.85|196.498|1.07|
|resnet18|29.06|29.975|1.03|
|resnet34|50.73|53.443|1.05|
|resnet50|76.88|80.602|1.05|
|resnext101_32x8d|277.29|280.759|1.01|
|resnext101_64x4d|269.56|281.052|1.04|
|resnext50_32x4d|100.73|101.102|1.00|
|shufflenet_v2_x0_5|10.56|15.419|1.46|
|shufflenet_v2_x1_0|13.11|18.525|1.41|
|shufflenet_v2_x1_5|18.05|23.132|1.28|
|shufflenet_v2_x2_0|25.04|30.008|1.20|
|squeezenet1_1|14.26|14.325|1.00|
|swin_b|264.52|274.613|1.04|
|swin_s|180.66|188.914|1.05|
|swin_t|108.62|112.632|1.04|
|swin_v2_s|220.29|231.153|1.05|
|swin_v2_t|127.27|133.586|1.05|
|vgg11|95.52|103.714|1.09|
|vgg11_bn|106.49|120.711|1.13|
|vgg13|132.94|147.063|1.11|
|vgg13_bn|149.73|165.256|1.10|
|vgg16|158.19|172.865|1.09|
|vgg16_bn|177.04|192.888|1.09|
|vgg19|184.76|194.194|1.05|
|vgg19_bn|203.30|213.334|1.05|
|vit_b_16|217.31|219.748|1.01|
|vit_b_32|69.47|75.692|1.09|
|vit_l_32|223.20|258.487|1.16|
|wide_resnet101_2|267.38|279.836|1.05|
|wide_resnet50_2|145.06|154.918|1.07|

In all cases this is faster than using `AveragedModel`; in many cases adding EMA introduces no overhead at all, since the computation is hidden behind the usual iteration flow. This implementation is similar to the one currently in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).

If the team is interested in merging this, let me know and I'll add documentation similar to `swa_utils`, plus tests. Credits to @szmigacz for the implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94820
Approved by: https://github.com/janeyx99
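For readers who want a concrete picture of the wrapper described above, here is a minimal, synchronous sketch (hypothetical code: the class name `SimpleEMAOptimizer` and all details below are illustrative only and are not the PR's `EMAOptimizer`, which performs the EMA update asynchronously so that it adds no CPU/GPU sync):

```python
import torch
from contextlib import contextmanager


class SimpleEMAOptimizer:
    """Hypothetical, simplified EMA optimizer wrapper (synchronous sketch)."""

    def __init__(self, optimizer, device, decay=0.9999):
        self.optimizer = optimizer
        self.decay = decay
        # Shadow copy of the parameters holding the EMA weights.
        # Assumes `device` is the device the model parameters live on
        # and that all parameters are floating point.
        self.ema_params = [p.detach().clone().to(device) for p in self._params()]

    def _params(self):
        return [p for group in self.optimizer.param_groups for p in group["params"]]

    @torch.no_grad()
    def step(self, closure=None):
        loss = self.optimizer.step(closure)
        # EMA update right after the optimizer step:
        #   ema <- decay * ema + (1 - decay) * param
        torch._foreach_lerp_(self.ema_params, self._params(), 1 - self.decay)
        return loss

    def zero_grad(self, set_to_none=True):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @contextmanager
    def swap_ema_weights(self):
        # Temporarily evaluate with the EMA weights, then restore the originals.
        params = self._params()
        backup = [p.detach().clone() for p in params]
        for p, ema in zip(params, self.ema_params):
            p.detach().copy_(ema)
        try:
            yield
        finally:
            for p, saved in zip(params, backup):
                p.detach().copy_(saved)
```

A wrapper with this interface drops into the usage example above; the benchmark numbers, however, come from the PR's asynchronous implementation.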
Committed by: PyTorch MergeBot
Parent commit: c6ab4ff35c
Commit: 45bf3f6216
@@ -254,32 +254,73 @@ algorithms.
lr_scheduler.OneCycleLR
lr_scheduler.CosineAnnealingWarmRestarts

Stochastic Weight Averaging
---------------------------
Weight Averaging (SWA and EMA)
------------------------------

:mod:`torch.optim.swa_utils` implements Stochastic Weight Averaging (SWA). In particular,
:class:`torch.optim.swa_utils.AveragedModel` class implements SWA models,
:mod:`torch.optim.swa_utils` implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). In particular,
the :class:`torch.optim.swa_utils.AveragedModel` class implements SWA and EMA models,
:class:`torch.optim.swa_utils.SWALR` implements the SWA learning rate scheduler and
:func:`torch.optim.swa_utils.update_bn` is a utility function used to update SWA batch
:func:`torch.optim.swa_utils.update_bn` is a utility function used to update SWA/EMA batch
normalization statistics at the end of training.

SWA has been proposed in `Averaging Weights Leads to Wider Optima and Better Generalization`_.

EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. It is a variation of `Polyak averaging`_, but using exponential weights instead of equal weights across iterations.

.. _`Averaging Weights Leads to Wider Optima and Better Generalization`: https://arxiv.org/abs/1803.05407

.. _`Polyak averaging`: https://paperswithcode.com/method/polyak-averaging

Constructing averaged models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`AveragedModel` class serves to compute the weights of the SWA model. You can create an
averaged model by running:
The `AveragedModel` class serves to compute the weights of the SWA or EMA model.

>>> swa_model = AveragedModel(model)
You can create an SWA averaged model by running:

Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``swa_model``
>>> averaged_model = AveragedModel(model)

EMA models are constructed by specifying the ``multi_avg_fn`` argument as follows:

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to ``get_ema_multi_avg_fn``, the default is 0.999.

``get_ema_multi_avg_fn`` returns a function that applies the following EMA equation to the weights:

.. math:: W^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

where alpha is the EMA decay.
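
As a quick numerical check of the equation above, using the helpers this PR adds (a sketch; it assumes the ``AveragedModel`` / ``get_ema_multi_avg_fn`` API introduced in the diff further below):

```python
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

decay = 0.999
model = torch.nn.Linear(4, 4)
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

ema_model.update_parameters(model)      # first call just copies the weights
w_ema = ema_model.module.weight.detach().clone()
with torch.no_grad():
    model.weight.add_(1.0)              # pretend a training step changed the weights
ema_model.update_parameters(model)      # applies the EMA equation above

expected = decay * w_ema + (1 - decay) * model.weight.detach()
print(torch.allclose(ema_model.module.weight, expected))  # True
```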
Here the model ``model`` can be an arbitrary :class:`torch.nn.Module` object. ``averaged_model``
will keep track of the running averages of the parameters of the ``model``. To update these
averages, you can use the :func:`update_parameters` function:
averages, you should use the :func:`update_parameters` function after the `optimizer.step()`:

>>> swa_model.update_parameters(model)
>>> averaged_model.update_parameters(model)

For SWA and EMA, this call is usually done right after the optimizer ``step()``. In the case of SWA, this is usually skipped for some number of steps at the beginning of training.

Custom averaging strategies
^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of
the parameters that you provide, but you can also use custom averaging functions with the
``avg_fn`` or ``multi_avg_fn`` parameters:

- ``avg_fn`` allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter.
- ``multi_avg_fn`` allows defining more efficient operations acting on a tuple of parameter lists (averaged parameter list, model parameter list) at the same time, for example using the ``torch._foreach*`` functions. This function must update the averaged parameters in-place.

In the following example ``ema_model`` computes an exponential moving average using the ``avg_fn`` parameter:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>     0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

In the following example ``ema_model`` computes an exponential moving average using the more efficient ``multi_avg_fn`` parameter:

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
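
A custom ``multi_avg_fn`` can also be written by hand. Below is a minimal sketch of an in-place EMA update over the parameter lists using the ``torch._foreach*`` ops (the helper name ``my_ema_multi_avg_fn`` is made up for this illustration, and the sketch assumes all parameters are floating point):

```python
import torch
from torch.optim.swa_utils import AveragedModel


def my_ema_multi_avg_fn(decay=0.9):
    @torch.no_grad()
    def update(averaged_param_list, current_param_list, num_averaged):
        # In-place: averaged <- decay * averaged + (1 - decay) * current
        torch._foreach_mul_(averaged_param_list, decay)
        torch._foreach_add_(averaged_param_list, current_param_list, alpha=1 - decay)
    return update


model = torch.nn.Linear(8, 2)
ema_model = AveragedModel(model, multi_avg_fn=my_ema_multi_avg_fn(0.9))
```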
SWA learning rate schedules
@@ -315,22 +356,10 @@ statistics for each batch normalization layer in the model.
``swa_model`` by doing a forward pass with the ``swa_model`` on each element of the dataset.

Custom averaging strategies
^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, :class:`torch.optim.swa_utils.AveragedModel` computes a running equal average of
the parameters that you provide, but you can also use custom averaging functions with the
``avg_fn`` parameter. In the following example ``ema_model`` computes an exponential moving average.

Example:

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>     0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

Putting it all together
^^^^^^^^^^^^^^^^^^^^^^^
Putting it all together: SWA
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the example below, ``swa_model`` is the SWA model that accumulates the averages of the weights.
We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule
@@ -357,3 +386,26 @@ and start to collect SWA averages of the parameters at epoch 160:
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

Putting it all together: EMA
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the example below, ``ema_model`` is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999.
We train the model for a total of 300 epochs and start to collect EMA averages immediately.

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>     multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>     for input, target in loader:
>>>         optimizer.zero_grad()
>>>         loss_fn(model(input), target).backward()
>>>         optimizer.step()
>>>         ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)
@@ -33,7 +33,7 @@ from torch.optim.lr_scheduler import (
    PolynomialLR,
    EPOCH_DEPRECATION_WARNING,
)
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn, get_swa_multi_avg_fn, get_ema_multi_avg_fn
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
@@ -3997,7 +3997,7 @@ class SWATestCNN(torch.nn.Module):

class TestSWAUtils(TestCase):
    def _test_averaged_model(self, net_device, swa_device):
    def _test_averaged_model(self, net_device, swa_device, ema):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
@@ -4010,32 +4010,48 @@ class TestSWAUtils(TestCase):
            torch.nn.Linear(5, 10),
        ).to(net_device)

        averaged_dnn = AveragedModel(dnn, device=swa_device)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)
        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p.device == net_device)
            self.assertTrue(p_avg.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def test_averaged_model_all_devices(self):
    def _run_averaged_steps(self, dnn, swa_device, ema):
        ema_decay = 0.999
        if ema:
            averaged_dnn = AveragedModel(dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay))
        else:
            averaged_dnn = AveragedModel(dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn())

        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]

        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if ema:
                    p_avg += p.detach() * ema_decay ** (n_updates - i - 1) * ((1 - ema_decay) if i > 0 else 1.0)
                else:
                    p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        return averaged_params, averaged_dnn

    @parametrize("ema", [True, False])
    def test_averaged_model_all_devices(self, ema):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu)
        self._test_averaged_model(cpu, cpu, ema)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu)
            self._test_averaged_model(cpu, cuda)
            self._test_averaged_model(cuda, cuda)
            self._test_averaged_model(cuda, cpu, ema)
            self._test_averaged_model(cpu, cuda, ema)
            self._test_averaged_model(cuda, cuda, ema)

    def test_averaged_model_mixed_device(self):
    @parametrize("ema", [True, False])
    def test_averaged_model_mixed_device(self, ema):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
@@ -4043,14 +4059,8 @@ class TestSWAUtils(TestCase):
        )
        dnn[0].cuda()
        dnn[1].cpu()
        averaged_dnn = AveragedModel(dnn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
@@ -4082,62 +4092,36 @@ class TestSWAUtils(TestCase):
        averaged_dnn = AveragedModel(dnn)
        pickle.dumps(averaged_dnn)

    def test_averaged_model_exponential(self):
        # Test AveragedModel with EMA as avg_fn
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        alpha = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p

        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn)
        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]
        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * alpha + p * (1 - alpha)).clone()
                    )
            for b in dnn.buffers():
                if b.size() != torch.Size([]):
                    b.detach_().add_(torch.randn_like(b))

            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
        for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
            self.assertEqual(b_avg, b_swa)

    def test_averaged_model_exponential_buffers(self):
    @parametrize("use_multi_avg_fn", [True, False])
    @parametrize("use_buffers", [True, False])
    def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
        # Test AveragedModel with EMA as avg_fn and use_buffers as True.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        alpha = 0.9
        decay = 0.9

        def avg_fn(p_avg, p, n_avg):
            return alpha * p_avg + (1 - alpha) * p
        if use_multi_avg_fn:
            averaged_dnn = AveragedModel(dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers)
        else:
            def avg_fn(p_avg, p, n_avg):
                return decay * p_avg + (1 - decay) * p

            averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)

        if use_buffers:
            dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
        else:
            dnn_params = list(dnn.parameters())

        averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True)
        dnn_params = itertools.chain(dnn.parameters(), dnn.buffers())
        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]

        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
@@ -4149,18 +4133,24 @@ class TestSWAUtils(TestCase):
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * alpha + p * (1 - alpha)).clone()
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        for p_avg, p_swa in zip(
            averaged_params,
            itertools.chain(
                averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
            ),
        ):
            self.assertEqual(p_avg, p_swa)
        if use_buffers:
            for p_avg, p_swa in zip(
                averaged_params,
                itertools.chain(
                    averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
                ),
            ):
                self.assertEqual(p_avg, p_swa)
        else:
            for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
                self.assertEqual(p_avg, p_swa)
            for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
                self.assertEqual(b_avg, b_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
@@ -4265,6 +4255,7 @@ class TestSWAUtils(TestCase):

instantiate_parametrized_tests(TestLRScheduler)
instantiate_parametrized_tests(TestSWAUtils)


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
@@ -7,30 +7,87 @@ import torch
from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler

__all__ = ['AveragedModel', 'update_bn', 'SWALR']
__all__ = [
    'AveragedModel',
    'update_bn',
    'SWALR',
    'get_ema_multi_avg_fn',
    'get_swa_multi_avg_fn',
    'get_ema_avg_fn',
    'get_swa_avg_fn'
]

from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype


def get_ema_multi_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param_list, current_param_list, _):
        # foreach lerp only handles float and complex
        if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(ema_param_list[0]):
            torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
        else:
            for p_ema, p_model in zip(ema_param_list, current_param_list):
                p_ema.copy_(p_ema * decay + p_model * (1 - decay))

    return ema_update


def get_swa_multi_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param_list, current_param_list, num_averaged):
        diffs = torch._foreach_sub(current_param_list, averaged_param_list)
        torch._foreach_addcdiv_(averaged_param_list, diffs, [num_averaged + 1] * len(averaged_param_list))

    return swa_update


def get_ema_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param

    return ema_update


def get_swa_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)

    return swa_update


class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA).
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and
    Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    AveragedModel class creates a copy of the provided module :attr:`model`
    on the device :attr:`device` and allows to compute running averages of the
    parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter and the number of models already averaged; if None,
            equally weighted average is used (default: None)
            parameter, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)
@@ -56,19 +113,18 @@ class AveragedModel(Module):
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with `avg_fn` parameter.
    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights.
    equally-weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged: (
        ...     0.1 * averaged_model_parameter + 0.9 * model_parameter)
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg, use_buffers=True)
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>     torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)

    .. note::
        When using SWA with models containing Batch Normalization you may
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        or by setting :attr:`use_buffers` to `True`. The first approach updates the
@@ -79,7 +135,7 @@ class AveragedModel(Module):
        approach yields the best results in your problem.

    .. note::
        :attr:`avg_fn` is not saved in the :meth:`state_dict` of the model.
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
@@ -98,15 +154,19 @@ class AveragedModel(Module):
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """
    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
    def __init__(self, model, device=None, avg_fn=None, multi_avg_fn=None, use_buffers=False):
        super().__init__()
        assert avg_fn is None or multi_avg_fn is None, 'Only one of avg_fn and multi_avg_fn should be provided'
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
@@ -121,28 +181,40 @@ class AveragedModel(Module):
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers else model.parameters()
        )
        for p_swa, p_model in zip(self_param, model_param):
            device = p_swa.device
        self_param_detached = []
        model_param_detached = []
        for p_averaged, p_model in zip(self_param, model_param):
            device = p_averaged.device
            p_model_ = p_model.detach().to(device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                grouped_tensors = _group_tensors_by_device_and_dtype([self_param_detached, model_param_detached])
                for ((device, _), [self_params, model_params]) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    elif device.type == 'cuda':
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(self_params, model_params, self.n_averaged.to(device))
                    else:
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                if self.avg_fn is None:
                    p_swa.detach().copy_(
                        p_swa.detach()
                        + (p_model_ - p_swa.detach()) / (self.n_averaged.to(device) + 1)
                    )
                else:
                    p_swa.detach().copy_(
                        self.avg_fn(
                            p_swa.detach(), p_model_, self.n_averaged.to(device)
                        )
                    )
                for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(self.avg_fn(p_averaged.detach(), p_model, n_averaged))

        if not self.use_buffers:
            # If not apply running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(device))
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1