sync AveragedModel buffers when use_buffers=False (#84054)

Fixes #84053

As described in the issue, `AveragedModel` deep-copies the source model at initialization, so with `use_buffers=False` the buffers of the averaged model (e.g. BatchNorm running statistics) are never updated alongside the source model's buffers.

The fix is to copy the source model's buffers into the averaged model on every call to `update_parameters`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84054
Approved by: https://github.com/samdow
Author: RangiLyu
Date: 2022-10-24 16:03:11 +00:00
Committed by: PyTorch MergeBot
Parent: 1bcd63d5e1
Commit: 512a3a48e3
2 changed files with 12 additions and 0 deletions


@@ -132,6 +132,11 @@ class AveragedModel(Module):
             else:
                 p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                  self.n_averaged.to(device)))
+        if not self.use_buffers:
+            # If not applying running averages to the buffers,
+            # keep the buffers in sync with the source model.
+            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
+                b_swa.detach().copy_(b_model.detach().to(device))
         self.n_averaged += 1
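
To illustrate the behavior this change guarantees, here is a minimal sketch (the model architecture and tensor shapes are illustrative only; it assumes a PyTorch build that includes this fix):

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# A model with buffers: BatchNorm1d tracks running_mean / running_var.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
swa_model = AveragedModel(model, use_buffers=False)

# A forward pass in training mode mutates BatchNorm's running statistics,
# so the deep-copied buffers in swa_model fall out of sync.
model.train()
model(torch.randn(8, 4))

# With this fix, update_parameters() also copies the source model's
# buffers into the averaged model when use_buffers=False.
swa_model.update_parameters(model)

# The averaged model's buffers now track the source model's buffers.
for b_swa, b_model in zip(swa_model.module.buffers(), model.buffers()):
    assert torch.equal(b_swa, b_model)
```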