sync AveragedModel buffers when use_buffers=False (#84054)
Fixes #84053. As described in the issue, AveragedModel deep-copies the source model during initialization, which means the buffers in the averaged model are never updated together with the model's. One solution is to copy the source model's buffers into the averaged model every time `update_parameters` is called. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84054 Approved by: https://github.com/samdow
Committed by: PyTorch MergeBot
Parent: 1bcd63d5e1
Commit: 512a3a48e3
@@ -132,6 +132,11 @@ class AveragedModel(Module):
             else:
                 p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                  self.n_averaged.to(device)))
+        if not self.use_buffers:
+            # If not apply running averages to the buffers,
+            # keep the buffers in sync with the source model.
+            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
+                b_swa.detach().copy_(b_model.detach().to(device))
         self.n_averaged += 1
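As a quick check of the behavior this commit fixes, here is a minimal sketch (not part of the commit) showing that with `use_buffers=False` the averaged model's buffers track the source model after `update_parameters`. The two-layer toy model and tensor shapes are illustrative assumptions, not from the PR:

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Toy model whose BatchNorm1d layer carries running_mean / running_var buffers.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
swa_model = AveragedModel(model, use_buffers=False)

# A forward pass in training mode updates the BatchNorm running stats
# on the source model only.
model.train()
model(torch.randn(8, 4))

# With this fix, update_parameters also copies the source model's buffers
# into the averaged copy, so the running stats stay in sync.
swa_model.update_parameters(model)
assert torch.allclose(swa_model.module[1].running_mean, model[1].running_mean)

Before this change, the assertion above would fail: the buffers were copied only once, at construction time, and then drifted apart from the source model's as training progressed.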