mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added option to update parameters using state_dict in AveragedModel (#65495)
Summary: While implementing [EMA](https://github.com/pytorch/vision/pull/4381) (which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used because it iterates only over parameters() and does not handle the entries of state_dict(), so a custom update_parameters() had to be defined in the [EMA class](https://github.com/pytorch/vision/pull/4406). This PR handles that scenario, removing the need for the custom update_parameters() implementation. Discussion: https://github.com/pytorch/vision/pull/4406#pullrequestreview-753734102 Pull Request resolved: https://github.com/pytorch/pytorch/pull/65495 Reviewed By: datumbox Differential Revision: D31176742 Pulled By: prabhat00155 fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3324bae5f1
commit
2ea724b1fd
@ -26,6 +26,8 @@ class AveragedModel(Module):
|
||||
:class:`AveragedModel` parameter, the current value of :attr:`model`
|
||||
parameter and the number of models already averaged; if None,
|
||||
equally weighted average is used (default: None)
|
||||
mode (str, optional): whether to use parameters or state_dict for update
|
||||
(default: parameters)
|
||||
|
||||
Example:
|
||||
>>> loader, optimizer, model, loss_fn = ...
|
||||
@ -84,7 +86,7 @@ class AveragedModel(Module):
|
||||
Generalizes Well:
|
||||
https://arxiv.org/abs/2001.02312
|
||||
"""
|
||||
def __init__(self, model, device=None, avg_fn=None):
|
||||
def __init__(self, model, device=None, avg_fn=None, mode='parameters'):
|
||||
super(AveragedModel, self).__init__()
|
||||
self.module = deepcopy(model)
|
||||
if device is not None:
|
||||
@ -96,12 +98,15 @@ class AveragedModel(Module):
|
||||
return averaged_model_parameter + \
|
||||
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
|
||||
self.avg_fn = avg_fn
|
||||
self.use_state_dict = mode == 'state_dict'
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
def update_parameters(self, model):
|
||||
for p_swa, p_model in zip(self.parameters(), model.parameters()):
|
||||
self_param = self.module.state_dict().values() if self.use_state_dict else self.parameters()
|
||||
model_param = model.state_dict().values() if self.use_state_dict else model.parameters()
|
||||
for p_swa, p_model in zip(self_param, model_param):
|
||||
device = p_swa.device
|
||||
p_model_ = p_model.detach().to(device)
|
||||
if self.n_averaged == 0:
|
||||
|
Reference in New Issue
Block a user