# mypy: allow-untyped-defs
import itertools
from collections.abc import Iterable, Iterator
from typing import Union

import torch
import torch.distributed as dist

# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an import error
# if we're trying to use them.
from torch.distributed import group, ProcessGroup

__all__ = [
    "average_parameters",
    "get_params_to_average",
    "average_parameters_or_parameter_groups",
]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.accelerator.is_available():
        torch.accelerator.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()

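
# Usage sketch: assuming every rank has already initialized a process group,
# e.g. via dist.init_process_group("gloo", rank=rank, world_size=world_size),
# the parameters of a replicated model could be averaged in place like this:
#
#     model = torch.nn.Linear(4, 4)
#     average_parameters(iter(model.parameters()), dist.group.WORLD)
#
# Every rank must call this collectively with parameters of identical shapes,
# since the flattened buffer is all-reduced across the group.

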
def get_params_to_average(
    params: Union[
        Iterable[torch.nn.Parameter],
        Iterable[dict[str, torch.nn.Parameter]],
    ],
):
    """
    Return a list of parameters that need to be averaged.

    This filters out the parameters that do not contain any gradients.

    Args:
        params: The parameters of a model or parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # model.parameters() input
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # optimizer.param_groups input
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(
                f"Parameter input of type {type(param)} is not supported"
            )
    return filtered_params

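
# Usage sketch: both input forms are accepted, and only parameters that already
# hold a .grad are returned, so a backward pass should run first:
#
#     model = torch.nn.Linear(4, 4)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     model(torch.randn(2, 4)).sum().backward()  # populate .grad
#     get_params_to_average(model.parameters())      # model.parameters() input
#     get_params_to_average(optimizer.param_groups)  # optimizer.param_groups input

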
def average_parameters_or_parameter_groups(
    params: Union[
        Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
    ],
    process_group: ProcessGroup,
):
    """Averages parameters of a model or parameter groups of an optimizer."""
    average_parameters(iter(get_params_to_average(params)), process_group)
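
# Usage sketch: a typical call site could average a model's gradient-bearing
# parameters across the default process group, for example as the averaging
# step of a periodic model-averaging scheme such as post-local SGD:
#
#     average_parameters_or_parameter_groups(optimizer.param_groups, dist.group.WORLD)
#
# Passing None as process_group falls back to group.WORLD inside
# average_parameters.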