pytorch/torch/distributed/algorithms/model_averaging/utils.py
Aaron Gokaslan 3555ebb63d [BE]: Update ruff to 0.11.8 (#153249)
Fixes a large number of false negatives throughout the codebase. Ruff now properly validates noqa comments, so most of the changes either fix typos in those comments or remove file-wide flake8 suppressions that were also silencing ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
2025-05-12 18:30:52 +00:00
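For context, the kind of cleanup described in the commit message might look like the following hypothetical before/after (not taken from the actual diff): a file-wide flake8 suppression that also silenced ruff is replaced by a targeted, correctly spelled noqa code that ruff can validate.

# Before (hypothetical): suppresses every flake8 *and* ruff finding in the file.
# flake8: noqa
import os  # unused import goes unreported

# After (hypothetical): only the intended rule is suppressed, and ruff can
# verify that the code in the noqa comment is a real rule and is still needed.
import os  # noqa: F401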


# mypy: allow-untyped-defs
import itertools
from collections.abc import Iterable, Iterator
from typing import Union

import torch
import torch.distributed as dist

# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise import error
# if we're trying to use them.
from torch.distributed import group, ProcessGroup


__all__ = [
    "average_parameters",
    "get_params_to_average",
    "average_parameters_or_parameter_groups",
]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.accelerator.is_available():
        torch.accelerator.synchronize()

    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()


def get_params_to_average(
    params: Union[
        Iterable[torch.nn.Parameter],
        Iterable[dict[str, torch.nn.Parameter]],
    ],
):
    """
    Return a list of parameters that need to be averaged.

    This filters out the parameters that do not contain any gradients.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # model.parameters() input
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # optimizer.param_groups input
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(param)} is not supported"
            )
    return filtered_params


def average_parameters_or_parameter_groups(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
):
    """Averages parameters of a model or parameter groups of an optimizer."""
    average_parameters(iter(get_params_to_average(params)), process_group)
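
A minimal usage sketch (not part of the file above): after a backward pass has populated gradients, average_parameters_or_parameter_groups can be called with either model.parameters() or optimizer.param_groups to average parameters elementwise across ranks. The process-group setup, backend, and toy model below are illustrative assumptions, not part of this module.

# Illustrative sketch only; assumes a gloo backend, env:// rendezvous, and a toy model.
import torch
import torch.distributed as dist
from torch.distributed.algorithms.model_averaging.utils import (
    average_parameters_or_parameter_groups,
)


def run() -> None:
    dist.init_process_group(backend="gloo", init_method="env://")
    model = torch.nn.Linear(4, 2)

    # A backward pass populates .grad, so get_params_to_average keeps these parameters.
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    # Average the model's parameters across all ranks in the default group.
    average_parameters_or_parameter_groups(model.parameters(), dist.group.WORLD)

    # Parameter groups taken from an optimizer are accepted as well.
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    average_parameters_or_parameter_groups(opt.param_groups, dist.group.WORLD)

    dist.destroy_process_group()


if __name__ == "__main__":
    run()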