remove unnecessary sync point in AveragedModel update (#158017)

Summary:
The test `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is called for each update.
This test is only meant to know whether the AveragedModel copy has been initialized or not.
This diff introduces a CPU-based variable for that purpose.
When loading from checkpoint we also make sure the parameter is refreshed.

After this fix, each `update_parameter` call is reduced to 6ms from 333ms (98% reduction).

Test Plan:
contbuild & OSS CI
Test plan from GitHub:
CI

Rollback Plan:

Differential Revision: D78074709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158017
Approved by: https://github.com/janeyx99
This commit is contained in:
Gael Le Lan
2025-09-16 18:57:52 +00:00
committed by PyTorch MergeBot
parent 5937861eba
commit cb7f45fd34
2 changed files with 78 additions and 13 deletions

View File

@ -76,7 +76,6 @@ class TestSWAUtils(TestCase):
# Check that AveragedModel is on the correct device # Check that AveragedModel is on the correct device
self.assertTrue(p_swa.device == swa_device) self.assertTrue(p_swa.device == swa_device)
self.assertTrue(p_avg.device == net_device) self.assertTrue(p_avg.device == net_device)
self.assertTrue(averaged_dnn.n_averaged.device == swa_device)
def _run_averaged_steps(self, dnn, swa_device, ema): def _run_averaged_steps(self, dnn, swa_device, ema):
ema_decay = 0.999 ema_decay = 0.999
@ -150,6 +149,44 @@ class TestSWAUtils(TestCase):
self.assertEqual(p_swa, p_swa2) self.assertEqual(p_swa, p_swa2)
self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged) self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)
def test_averaged_model_backward_compatibility(self):
"""Test that AveragedModel correctly handles old checkpoints with tensor n_averaged."""
dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
)
averaged_dnn = AveragedModel(dnn)
# Update parameters a few times
n_updates = 5
for _ in range(n_updates):
for p in dnn.parameters():
p.detach().add_(torch.randn_like(p))
averaged_dnn.update_parameters(dnn)
# Manually create a state dict with tensor n_averaged (simulating old checkpoint)
state_dict = averaged_dnn.state_dict()
# Create an old-style tensor n_averaged
old_n_averaged = torch.tensor(n_updates, dtype=torch.long)
state_dict["n_averaged"] = old_n_averaged
# Create new model and load the old-style state dict
averaged_dnn2 = AveragedModel(dnn)
averaged_dnn2.load_state_dict(state_dict)
# Check that n_averaged was correctly loaded as a Python int
self.assertEqual(averaged_dnn2.n_averaged, n_updates)
self.assertIsInstance(averaged_dnn2.n_averaged, int)
# Verify that parameters are correctly loaded
for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
self.assertEqual(p_swa, p_swa2)
# Test that we can continue to update parameters without issues
for p in dnn.parameters():
p.detach().add_(torch.randn_like(p))
averaged_dnn2.update_parameters(dnn)
self.assertEqual(averaged_dnn2.n_averaged, n_updates + 1)
def test_averaged_model_default_avg_fn_picklable(self): def test_averaged_model_default_avg_fn_picklable(self):
dnn = torch.nn.Sequential( dnn = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Conv2d(1, 5, kernel_size=3),

View File

@ -116,6 +116,28 @@ def get_swa_avg_fn():
return swa_update return swa_update
def _load_state_dict_pre_hook(
module,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Pre-hook to handle backward compatibility with tensor n_averaged."""
# Check if the old tensor n_averaged is present in the state dict
n_averaged_key = prefix + "n_averaged"
if n_averaged_key in state_dict:
# Convert tensor n_averaged to Python int for backward compatibility
n_averaged_tensor = state_dict[n_averaged_key]
if isinstance(n_averaged_tensor, Tensor):
module.n_averaged = int(n_averaged_tensor.item())
# Remove the old tensor buffer from state_dict to avoid loading it
del state_dict[n_averaged_key]
class AveragedModel(Module): class AveragedModel(Module):
r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA). r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
@ -215,7 +237,7 @@ class AveragedModel(Module):
https://paperswithcode.com/method/polyak-averaging https://paperswithcode.com/method/polyak-averaging
""" """
n_averaged: Tensor n_averaged: int
def __init__( def __init__(
self, self,
@ -234,17 +256,25 @@ class AveragedModel(Module):
self.module = deepcopy(model) self.module = deepcopy(model)
if device is not None: if device is not None:
self.module = self.module.to(device) self.module = self.module.to(device)
self.register_buffer( self.n_averaged = 0
"n_averaged", torch.tensor(0, dtype=torch.long, device=device)
)
self.avg_fn = avg_fn self.avg_fn = avg_fn
self.multi_avg_fn = multi_avg_fn self.multi_avg_fn = multi_avg_fn
self.use_buffers = use_buffers self.use_buffers = use_buffers
self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
"""Forward pass.""" """Forward pass."""
return self.module(*args, **kwargs) return self.module(*args, **kwargs)
def get_extra_state(self) -> Any:
"""Get extra state for serialization."""
return {"n_averaged": self.n_averaged}
def set_extra_state(self, state: Any) -> None:
"""Set extra state from deserialization."""
if isinstance(state, dict) and "n_averaged" in state:
self.n_averaged = state["n_averaged"]
def update_parameters(self, model: Module): def update_parameters(self, model: Module):
"""Update model parameters.""" """Update model parameters."""
self_param = ( self_param = (
@ -280,28 +310,26 @@ class AveragedModel(Module):
self.multi_avg_fn( self.multi_avg_fn(
self_params, # type: ignore[arg-type] self_params, # type: ignore[arg-type]
model_params, # type: ignore[arg-type] model_params, # type: ignore[arg-type]
self.n_averaged.to(device), self.n_averaged,
) )
elif ( elif (
device is not None device is not None
and device.type in _get_foreach_kernels_supported_devices() and device.type in _get_foreach_kernels_supported_devices()
): ):
multi_avg_fn = get_swa_multi_avg_fn() multi_avg_fn = get_swa_multi_avg_fn()
multi_avg_fn( multi_avg_fn(self_params, model_params, self.n_averaged)
self_params, model_params, self.n_averaged.to(device)
)
else: else:
avg_fn = get_swa_avg_fn() avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) p_averaged.copy_(
avg_fn(p_averaged, p_model, self.n_averaged)
)
else: else:
for p_averaged, p_model in zip( # type: ignore[assignment] for p_averaged, p_model in zip( # type: ignore[assignment]
self_param_detached, model_param_detached self_param_detached, model_param_detached
): ):
n_averaged = self.n_averaged.to(p_averaged.device)
p_averaged.detach().copy_( p_averaged.detach().copy_(
self.avg_fn(p_averaged.detach(), p_model, n_averaged) self.avg_fn(p_averaged.detach(), p_model, self.n_averaged)
) )
if not self.use_buffers: if not self.use_buffers: