Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
swa avoid stream sync (#157705)
Summary: When AveragedModel.update_parameters runs, it evaluates self.n_averaged == 0 once per parameter, and n_averaged is a buffer that lives on the GPU, so each check forces a stream synchronization. Moving the check before the loop means the sync happens only once. This improves update_parameters from 74ms to 57ms, a ~22% improvement. {F1980011097} {F1980011111}

Test Plan: CI

Rollback Plan:

Differential Revision: D77723025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157705
Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/janeyx99
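For illustration, here is a minimal standalone sketch of the pattern this change adopts (the variable names below are placeholders, not the actual swa_utils.py code): comparing a GPU buffer such as n_averaged against a Python number produces a device tensor, and converting that tensor to a Python bool inside the loop forces a host/device synchronization on every iteration; reading it once before the loop synchronizes only once.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the AveragedModel state: the counter lives on the same device
# as the parameters (typically the GPU).
n_averaged = torch.tensor(0, dtype=torch.long, device=device)
averaged = [torch.zeros(1024, device=device) for _ in range(100)]
current = [torch.randn(1024, device=device) for _ in range(100)]

# Before: the comparison is evaluated inside the loop. Each `if` converts a
# device tensor to a Python bool, synchronizing the stream once per parameter.
for p_averaged, p_model in zip(averaged, current):
    if n_averaged == 0:
        p_averaged.copy_(p_model)

# After: hoist the check out of the loop so the synchronization happens once.
copy_param = bool(n_averaged == 0)
for p_averaged, p_model in zip(averaged, current):
    if copy_param:
        p_averaged.copy_(p_model)

The diff below applies exactly this hoist inside AveragedModel.update_parameters.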
This commit is contained in:
committed by PyTorch MergeBot
parent c2510fcd86
commit 2efa5eaa65
@@ -259,11 +259,12 @@ class AveragedModel(Module):
         )
         self_param_detached: list[Optional[Tensor]] = []
         model_param_detached: list[Optional[Tensor]] = []
+        copy_param = bool(self.n_averaged == 0)
         for p_averaged, p_model in zip(self_param, model_param):
             p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)
-            if self.n_averaged == 0:
+            if copy_param:
                 p_averaged.detach().copy_(p_model_)

         if self.n_averaged > 0: