swa avoid stream sync (#157705)

Summary:
When AveragedModel.update_parameters is called, it evaluates self.n_averaged == 0 for each parameter, where n_averaged is a buffer on the GPU, so every check forces a stream sync. Moving the check before the loop so the sync happens only once.

This improves update_parameters from 74 ms to 57 ms, an ~22% improvement.
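
For illustration, a minimal sketch (not the PyTorch source; the tensor names and sizes are made up) of the pattern this change addresses: evaluating a comparison against a GPU tensor as a Python bool inside the loop syncs on every iteration, while hoisting it out of the loop syncs once.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_averaged = torch.tensor(0, device=device)  # analogous to the n_averaged buffer
    params = [torch.randn(1024, device=device) for _ in range(100)]

    # Before: `if n_averaged == 0` converts a GPU tensor to a Python bool,
    # which blocks on a device-to-host read once per parameter.
    for p in params:
        if n_averaged == 0:
            p.zero_()  # stand-in for the copy_() done in update_parameters

    # After: the GPU buffer is read back exactly once, before the loop.
    copy_param = bool(n_averaged == 0)
    for p in params:
        if copy_param:
            p.zero_()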

Test Plan:
CI

Rollback Plan:

Differential Revision: D77723025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157705
Approved by: https://github.com/albanD, https://github.com/Skylion007, https://github.com/janeyx99
Author: Olga Gerasimova
Date: 2025-07-07 20:47:35 +00:00
Committed by: PyTorch MergeBot
Parent: c2510fcd86
Commit: 2efa5eaa65


@@ -259,11 +259,12 @@ class AveragedModel(Module):
         )
         self_param_detached: list[Optional[Tensor]] = []
         model_param_detached: list[Optional[Tensor]] = []
+        copy_param = bool(self.n_averaged == 0)
         for p_averaged, p_model in zip(self_param, model_param):
             p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)
-            if self.n_averaged == 0:
+            if copy_param:
                 p_averaged.detach().copy_(p_model_)

         if self.n_averaged > 0:
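
For context, a typical call site of the method this hunk touches, using the public torch.optim.swa_utils API (the model, optimizer, and training loop below are illustrative, not part of this commit):

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(10, 2)
    swa_model = AveragedModel(model)                # holds the n_averaged buffer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for _ in range(5):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).sum()
        loss.backward()
        optimizer.step()
        swa_model.update_parameters(model)          # the call path optimized here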