remove unnecessary sync point in AveragedModel update (#158017)
Summary:
The test `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is hit on every update, even though it exists only to tell whether the AveragedModel copy has been initialized. This diff introduces a CPU-side Python int for that purpose. When loading from a checkpoint, the value is refreshed as well. After this fix, each `update_parameters` call drops from 333ms to 6ms (a 98% reduction).

Test Plan: contbuild & OSS CI

Test plan from GitHub:
CI

Rollback Plan:

Differential Revision: D78074709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158017
Approved by: https://github.com/janeyx99
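Why the old check stalls the pipeline: comparing a CUDA tensor and converting the result to a Python bool forces the host to wait for the device. A minimal sketch of the two approaches (my illustration, not part of the PR; assumes a CUDA device is available):

    import torch

    if torch.cuda.is_available():
        # Old approach: the counter lived on the GPU as a long tensor.
        n_averaged_tensor = torch.tensor(0, dtype=torch.long, device="cuda")
        # bool(tensor == 0) copies the comparison result back to the CPU,
        # which blocks until all queued GPU kernels have finished.
        initialized = bool(n_averaged_tensor == 0)

        # New approach: a plain Python int lives on the host, so the same
        # check costs nothing and issues no device work.
        n_averaged = 0
        initialized = n_averaged == 0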
commit cb7f45fd34 (parent 5937861eba)
committed by PyTorch MergeBot
@@ -76,7 +76,6 @@ class TestSWAUtils(TestCase):
             # Check that AveragedModel is on the correct device
             self.assertTrue(p_swa.device == swa_device)
             self.assertTrue(p_avg.device == net_device)
-        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

     def _run_averaged_steps(self, dnn, swa_device, ema):
         ema_decay = 0.999
@@ -150,6 +149,44 @@ class TestSWAUtils(TestCase):
             self.assertEqual(p_swa, p_swa2)
         self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

+    def test_averaged_model_backward_compatibility(self):
+        """Test that AveragedModel correctly handles old checkpoints with tensor n_averaged."""
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
+        )
+        averaged_dnn = AveragedModel(dnn)
+
+        # Update parameters a few times
+        n_updates = 5
+        for _ in range(n_updates):
+            for p in dnn.parameters():
+                p.detach().add_(torch.randn_like(p))
+            averaged_dnn.update_parameters(dnn)
+
+        # Manually create a state dict with tensor n_averaged (simulating old checkpoint)
+        state_dict = averaged_dnn.state_dict()
+        # Create an old-style tensor n_averaged
+        old_n_averaged = torch.tensor(n_updates, dtype=torch.long)
+        state_dict["n_averaged"] = old_n_averaged
+
+        # Create new model and load the old-style state dict
+        averaged_dnn2 = AveragedModel(dnn)
+        averaged_dnn2.load_state_dict(state_dict)
+
+        # Check that n_averaged was correctly loaded as a Python int
+        self.assertEqual(averaged_dnn2.n_averaged, n_updates)
+        self.assertIsInstance(averaged_dnn2.n_averaged, int)
+
+        # Verify that parameters are correctly loaded
+        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
+            self.assertEqual(p_swa, p_swa2)
+
+        # Test that we can continue to update parameters without issues
+        for p in dnn.parameters():
+            p.detach().add_(torch.randn_like(p))
+        averaged_dnn2.update_parameters(dnn)
+        self.assertEqual(averaged_dnn2.n_averaged, n_updates + 1)
+
     def test_averaged_model_default_avg_fn_picklable(self):
         dnn = torch.nn.Sequential(
             torch.nn.Conv2d(1, 5, kernel_size=3),
@@ -116,6 +116,28 @@ def get_swa_avg_fn():
     return swa_update


+def _load_state_dict_pre_hook(
+    module,
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    """Pre-hook to handle backward compatibility with tensor n_averaged."""
+    # Check if the old tensor n_averaged is present in the state dict
+    n_averaged_key = prefix + "n_averaged"
+    if n_averaged_key in state_dict:
+        # Convert tensor n_averaged to Python int for backward compatibility
+        n_averaged_tensor = state_dict[n_averaged_key]
+        if isinstance(n_averaged_tensor, Tensor):
+            module.n_averaged = int(n_averaged_tensor.item())
+            # Remove the old tensor buffer from state_dict to avoid loading it
+            del state_dict[n_averaged_key]
+
+
 class AveragedModel(Module):
     r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

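For context, `register_load_state_dict_pre_hook` runs a hook like the one above on the incoming state dict before its keys are matched against the module. A self-contained sketch of the same migration pattern (`Counter` and `_pre_hook` are illustrative names, not part of this PR):

    import torch
    from torch import nn

    class Counter(nn.Module):
        def __init__(self):
            super().__init__()
            self.n = 0  # plain attribute, intentionally not a buffer

    def _pre_hook(module, state_dict, prefix, local_metadata, strict,
                  missing_keys, unexpected_keys, error_msgs):
        key = prefix + "n"
        if key in state_dict and isinstance(state_dict[key], torch.Tensor):
            # Migrate the old tensor entry to a host-side int and drop the
            # stale key so load_state_dict never sees it.
            module.n = int(state_dict[key].item())
            del state_dict[key]

    m = Counter()
    m.register_load_state_dict_pre_hook(_pre_hook)
    m.load_state_dict({"n": torch.tensor(7, dtype=torch.long)})
    assert m.n == 7 and isinstance(m.n, int)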
@@ -215,7 +237,7 @@ class AveragedModel(Module):
        https://paperswithcode.com/method/polyak-averaging
    """

-    n_averaged: Tensor
+    n_averaged: int

    def __init__(
        self,
@@ -234,17 +256,25 @@ class AveragedModel(Module):
         self.module = deepcopy(model)
         if device is not None:
             self.module = self.module.to(device)
-        self.register_buffer(
-            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
-        )
+        self.n_averaged = 0
         self.avg_fn = avg_fn
         self.multi_avg_fn = multi_avg_fn
         self.use_buffers = use_buffers
+        self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook)

     def forward(self, *args, **kwargs):
         """Forward pass."""
         return self.module(*args, **kwargs)

+    def get_extra_state(self) -> Any:
+        """Get extra state for serialization."""
+        return {"n_averaged": self.n_averaged}
+
+    def set_extra_state(self, state: Any) -> None:
+        """Set extra state from deserialization."""
+        if isinstance(state, dict) and "n_averaged" in state:
+            self.n_averaged = state["n_averaged"]
+
     def update_parameters(self, model: Module):
         """Update model parameters."""
         self_param = (
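`get_extra_state`/`set_extra_state` are the supported way to round-trip non-tensor state: whatever `get_extra_state` returns is saved under the `_extra_state` key of the state dict and handed back to `set_extra_state` on load. A hedged sketch with a stand-in module (`Counter` is my example, not part of torch.optim.swa_utils):

    import torch
    from torch import nn

    class Counter(nn.Module):
        def __init__(self):
            super().__init__()
            self.n = 0

        def get_extra_state(self):
            # Stored in state_dict() under the "_extra_state" key.
            return {"n": self.n}

        def set_extra_state(self, state):
            self.n = state["n"]

    src, dst = Counter(), Counter()
    src.n = 5
    dst.load_state_dict(src.state_dict())
    assert dst.n == 5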
@@ -280,28 +310,26 @@ class AveragedModel(Module):
                     self.multi_avg_fn(
                         self_params,  # type: ignore[arg-type]
                         model_params,  # type: ignore[arg-type]
-                        self.n_averaged.to(device),
+                        self.n_averaged,
                     )
                 elif (
                     device is not None
                     and device.type in _get_foreach_kernels_supported_devices()
                 ):
                     multi_avg_fn = get_swa_multi_avg_fn()
-                    multi_avg_fn(
-                        self_params, model_params, self.n_averaged.to(device)
-                    )
+                    multi_avg_fn(self_params, model_params, self.n_averaged)
                 else:
                     avg_fn = get_swa_avg_fn()
-                    n_averaged = self.n_averaged.to(device)
                     for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
-                        p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
+                        p_averaged.copy_(
+                            avg_fn(p_averaged, p_model, self.n_averaged)
+                        )
         else:
             for p_averaged, p_model in zip(  # type: ignore[assignment]
                 self_param_detached, model_param_detached
             ):
-                n_averaged = self.n_averaged.to(p_averaged.device)
                 p_averaged.detach().copy_(
-                    self.avg_fn(p_averaged.detach(), p_model, n_averaged)
+                    self.avg_fn(p_averaged.detach(), p_model, self.n_averaged)
                 )

         if not self.use_buffers:
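End-to-end effect of the change, as a usage sketch (assumes a torch build that includes this patch):

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(4, 2)
    swa_model = AveragedModel(model)
    for _ in range(3):
        # ... training step / optimizer.step() would go here ...
        swa_model.update_parameters(model)

    # n_averaged is now a plain Python int, so checking it no longer
    # forces a CPU/GPU synchronization on every update.
    assert swa_model.n_averaged == 3
    assert isinstance(swa_model.n_averaged, int)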