Summary: Implement hierarchical model averaging proposed in https://github.com/pytorch/pytorch/issues/71325. Unit tests are added. Since I don't have access to 4-GPU machines in the open-source environment, I expect that a branch with the `ci-all` prefix can run the test that requires 4 GPUs. In the future, the internals of `PeriodicModelAveraging` can be simplified into a specialized hierarchical model averaging, where `period_group_size_dict` only has a single pair of period and world size. Pull Request resolved: https://github.com/pytorch/pytorch/pull/73285 Reviewed By: mrshenli Differential Revision: D34457792 Pulled By: rohan-varma fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba (cherry picked from commit 5f543d46103edb515db199dbb80db43c85665f29)
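The commit summary above describes a hierarchical schedule keyed by `period_group_size_dict` and notes that periodic averaging is its single-entry special case. The following is a minimal sketch under stated assumptions: the schedule values, the 16-rank job size, and the shape of the mapping are illustrative only and are not quoted from this PR's diff.

# Sketch only: the schedule values below and the notion of an averager that
# consumes them are illustrative assumptions, not taken from the PR diff.
from collections import OrderedDict

# Hypothetical schedule for a 16-rank job: average within 4-rank subgroups
# every 2 steps, and across all 16 ranks every 8 steps.
period_group_size_dict = OrderedDict([(2, 4), (8, 16)])

# The special case mentioned in the summary: a schedule with a single pair of
# period and world size behaves like periodic averaging with period=4.
single_level_schedule = OrderedDict([(4, 16)])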
114 lines
4.7 KiB
Python
import warnings
from abc import ABC, abstractmethod

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils


class ModelAverager(ABC):
    r"""Base class for all model averagers.

    Args:
        process_group: The process group to be used for all-reduce.
            If ``None``, the default process group, which
            is created by :func:`torch.distributed.init_process_group`,
            will be used. (default: ``None``)
    """

    def __init__(self, process_group=None):
        self.process_group = (
            process_group if process_group is not None else dist.group.WORLD
        )
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        raise NotImplementedError


class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.parallel.DistributedDataParallel` (DDP)
    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
            Usually the period should be greater than ``1`` to reduce the communication cost.
            Otherwise, only DDP needs to be used.
        warmup_steps (int): The number of warm-up steps. During this stage,
            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
            If ``None``, the default process group, which
            is created by :func:`torch.distributed.init_process_group`,
            will be used. (default: ``None``)

    Example::

        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).cuda()
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging every 4 steps.
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Will average model parameters globally every 4 steps. Thus,
        >>>     # inter-node communication only occurs every 4 iterations after
        >>>     # the initial ``warmup_steps`` period.
        >>>     averager.average_parameters(model.parameters())
    """

    def __init__(
        self,
        period,
        warmup_steps=0,
        process_group=None,
    ):
        super().__init__(process_group)
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        elif period == 1:
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(self, params):
        r"""
        Averages parameters if ``step`` is no less than ``warmup_steps``
        and the number of steps since warm-up is divisible by ``period``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        """
        if (
            self.step >= self.warmup_steps
            and (self.step - self.warmup_steps) % self.period == 0
        ):
            utils.average_parameters(iter(params), self.process_group)
        self.step += 1
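For a quick, standalone illustration of the scheduling condition implemented by `average_parameters` above, here is a sketch that reproduces only the trigger arithmetic; no process group or model is involved, and the `warmup_steps=100`, `period=4` values are taken from the docstring example.

# Sketch only: lists the steps at which averaging would trigger, i.e. once
# ``step`` passes the warm-up stage and the number of steps since warm-up
# is divisible by ``period``.
warmup_steps = 100
period = 4
averaging_steps = [
    step
    for step in range(200)
    if step >= warmup_steps and (step - warmup_steps) % period == 0
]
print(averaging_steps[:5])  # [100, 104, 108, 112, 116]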