Added option to update parameters using state_dict in AveragedModel (#65495)

Summary:
While implementing [EMA](https://github.com/pytorch/vision/pull/4381) (which extends AveragedModel) in torchvision, AveragedModel's update_parameters() could not be used because it did not handle state_dict(), so a custom update_parameters() had to be defined in the [EMA class](https://github.com/pytorch/vision/pull/4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation.
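
With this change, an EMA model can be built directly on `AveragedModel` by passing a decay-based `avg_fn` together with `mode='state_dict'`. A minimal sketch of that usage (the model and decay value below are illustrative, not from this PR):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.BatchNorm1d(5))

# avg_fn receives (averaged_param, current_param, num_averaged).
decay = 0.999
ema_model = AveragedModel(
    model,
    avg_fn=lambda avg, cur, num_averaged: decay * avg + (1.0 - decay) * cur,
    mode='state_dict',  # average state_dict() entries instead of parameters()
)

# In the training loop, after each optimizer step:
ema_model.update_parameters(model)
```

Note that in `'state_dict'` mode integer buffers (e.g. BatchNorm's `num_batches_tracked`) also pass through `avg_fn`.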

Discussion: https://github.com/pytorch/vision/pull/4406#pullrequestreview-753734102

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
Author: Prabhat Roy
Date: 2021-09-28 03:33:16 -07:00
Committed by: Facebook GitHub Bot
Parent: 3324bae5f1
Commit: 2ea724b1fd

2 changed files with 39 additions and 2 deletions

@@ -26,6 +26,8 @@ class AveragedModel(Module):
         :class:`AveragedModel` parameter, the current value of :attr:`model`
             parameter and the number of models already averaged; if None,
             equally weighted average is used (default: None)
+        mode (str, optional): whether to use parameters or state_dict for update
+            (default: parameters)
 
     Example:
         >>> loader, optimizer, model, loss_fn = ...
@@ -84,7 +86,7 @@ class AveragedModel(Module):
     Generalizes Well:
         https://arxiv.org/abs/2001.02312
     """
-    def __init__(self, model, device=None, avg_fn=None):
+    def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
         super(AveragedModel, self).__init__()
         self.module = deepcopy(model)
         if device is not None:
@@ -96,12 +98,15 @@ class AveragedModel(Module):
                 return averaged_model_parameter + \
                     (model_parameter - averaged_model_parameter) / (num_averaged + 1)
             self.avg_fn = avg_fn
+        self.use_state_dict = mode == 'state_dict'
 
     def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
 
     def update_parameters(self, model):
-        for p_swa, p_model in zip(self.parameters(), model.parameters()):
+        self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
+        model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
+        for p_swa, p_model in zip(self_param, model_param):
             device = p_swa.device
             p_model_ = p_model.detach().to(device)
             if self.n_averaged == 0:
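
The practical difference between the two modes: `parameters()` yields only learnable parameters, while `state_dict()` also contains buffers, so `mode='state_dict'` lets the averaging cover running statistics such as BatchNorm's `running_mean`/`running_var`. A quick illustration:

```python
import torch

bn = torch.nn.BatchNorm1d(3)
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias'] -- buffers are not parameters
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```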