Revert "remove unnecessary sync point in AveragedModel update (#158017)"

This reverts commit cb7f45fd34b890fa7665837573ebb25744889568.

Reverted https://github.com/pytorch/pytorch/pull/158017 on behalf of https://github.com/wdvr: as discussed with the author, this change is expected to break checkpointing ([comment](https://github.com/pytorch/pytorch/pull/158017#issuecomment-3301790645))
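For reference, a minimal sketch (not part of the original commit) of the buffer-based behavior this revert restores. It assumes only the public `AveragedModel` API: with `n_averaged` registered as a buffer, the counter round-trips through `state_dict()` under its own key, which checkpointing code may rely on.

```python
import torch
from torch.optim.swa_utils import AveragedModel

# With the revert applied, n_averaged is a registered buffer, so it
# shows up in state_dict() as a tensor under its own key.
model = torch.nn.Linear(4, 2)
ema = AveragedModel(model)
ema.update_parameters(model)  # increments n_averaged to 1

sd = ema.state_dict()
assert "n_averaged" in sd             # stored as a buffer entry
assert sd["n_averaged"].item() == 1

# A fresh AveragedModel can load the checkpoint directly.
ema2 = AveragedModel(torch.nn.Linear(4, 2))
ema2.load_state_dict(sd)
```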
Author: PyTorch MergeBot
Date: 2025-09-17 08:02:02 +00:00
Parent: a63221a335
Commit: a5419743c6

2 changed files with 13 additions and 78 deletions

torch/optim/swa_utils.py

@@ -116,28 +116,6 @@ def get_swa_avg_fn():
     return swa_update


-def _load_state_dict_pre_hook(
-    module,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    """Pre-hook to handle backward compatibility with tensor n_averaged."""
-    # Check if the old tensor n_averaged is present in the state dict
-    n_averaged_key = prefix + "n_averaged"
-    if n_averaged_key in state_dict:
-        # Convert tensor n_averaged to Python int for backward compatibility
-        n_averaged_tensor = state_dict[n_averaged_key]
-        if isinstance(n_averaged_tensor, Tensor):
-            module.n_averaged = int(n_averaged_tensor.item())
-        # Remove the old tensor buffer from state_dict to avoid loading it
-        del state_dict[n_averaged_key]
-
-
 class AveragedModel(Module):
     r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
@@ -237,7 +215,7 @@ class AveragedModel(Module):
         https://paperswithcode.com/method/polyak-averaging
     """

-    n_averaged: int
+    n_averaged: Tensor

     def __init__(
         self,
@@ -256,25 +234,17 @@ class AveragedModel(Module):
         self.module = deepcopy(model)
         if device is not None:
             self.module = self.module.to(device)
-        self.n_averaged = 0
+        self.register_buffer(
+            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
+        )
         self.avg_fn = avg_fn
         self.multi_avg_fn = multi_avg_fn
         self.use_buffers = use_buffers
-        self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook)

     def forward(self, *args, **kwargs):
         """Forward pass."""
         return self.module(*args, **kwargs)

-    def get_extra_state(self) -> Any:
-        """Get extra state for serialization."""
-        return {"n_averaged": self.n_averaged}
-
-    def set_extra_state(self, state: Any) -> None:
-        """Set extra state from deserialization."""
-        if isinstance(state, dict) and "n_averaged" in state:
-            self.n_averaged = state["n_averaged"]
-
     def update_parameters(self, model: Module):
         """Update model parameters."""
         self_param = (
@@ -310,26 +280,28 @@ class AveragedModel(Module):
                         self.multi_avg_fn(
                             self_params,  # type: ignore[arg-type]
                             model_params,  # type: ignore[arg-type]
-                            self.n_averaged,
+                            self.n_averaged.to(device),
                         )
                     elif (
                         device is not None
                         and device.type in _get_foreach_kernels_supported_devices()
                     ):
                         multi_avg_fn = get_swa_multi_avg_fn()
-                        multi_avg_fn(self_params, model_params, self.n_averaged)
+                        multi_avg_fn(
+                            self_params, model_params, self.n_averaged.to(device)
+                        )
                     else:
                         avg_fn = get_swa_avg_fn()
+                        n_averaged = self.n_averaged.to(device)
                         for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
-                            p_averaged.copy_(
-                                avg_fn(p_averaged, p_model, self.n_averaged)
-                            )
+                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
             else:
                 for p_averaged, p_model in zip(  # type: ignore[assignment]
                     self_param_detached, model_param_detached
                 ):
+                    n_averaged = self.n_averaged.to(p_averaged.device)
                     p_averaged.detach().copy_(
-                        self.avg_fn(p_averaged.detach(), p_model, self.n_averaged)
+                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                     )

         if not self.use_buffers:
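For context on the trade-off, the reverted PR had replaced the `n_averaged` buffer with a plain Python int precisely to avoid a host-device synchronization. A brief illustrative sketch (not part of this diff, and assuming a CUDA device is available) of where that kind of sync comes from:

```python
import torch

# Python control flow that reads a CUDA scalar tensor forces the host
# to wait for the GPU; a plain Python int incurs no such sync.
if torch.cuda.is_available():
    n_averaged = torch.tensor(3, dtype=torch.long, device="cuda")
    if n_averaged > 0:  # implicit .item() synchronizes the device
        print("averaged", int(n_averaged), "times")
```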