Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Model Averaging] Support hierarchical model averaging (#73285)
Summary: Implement hierarchical model averaging as proposed in https://github.com/pytorch/pytorch/issues/71325. Unit tests are added. Since I don't have access to 4-GPU machines in the open-source environment, I expect that a branch with the `ci-all` prefix can run the tests that require 4 GPUs. In the future, the internals of `PeriodicModelAverager` can be simplified into a specialized hierarchical model averager whose `period_group_size_dict` holds only a single pair of period and world size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73285
Reviewed By: mrshenli
Differential Revision: D34457792
Pulled By: rohan-varma
fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba
(cherry picked from commit 5f543d46103edb515db199dbb80db43c85665f29)
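To make the note about `period_group_size_dict` concrete, below is a minimal single-process sketch (not part of this commit) showing that a one-entry dict reduces hierarchical averaging to the existing `PeriodicModelAverager`. The gloo backend, world size 1, and the address/port values are illustrative assumptions chosen so the snippet can run on a single CPU machine; in real training each rank would pass its own rank and world size.

import os
from collections import OrderedDict

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.nn as nn

# Single-process setup on CPU; the address and port are placeholder assumptions.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(1, 1, bias=False)

# A hierarchy whose only level spans the whole world with period 4 ...
hier = hierarchicalSGD.HierarchicalModelAverager(
    period_group_size_dict=OrderedDict([(4, dist.get_world_size())]), warmup_steps=0
)
# ... compared against the existing periodic averager with the same period.
flat = averagers.PeriodicModelAverager(period=4, warmup_steps=0)

for _ in range(8):
    # With warmup_steps=0, both averagers all-reduce the parameters at steps 0 and 4.
    hier.average_parameters(model.parameters())
    flat.average_parameters(model.parameters())

dist.destroy_process_group()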
LICENSE
@@ -28,6 +28,10 @@ All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain

+All contributions by Cruise LLC:
+Copyright (c) 2022 Cruise LLC.
+All rights reserved.
+
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

@@ -94,7 +94,7 @@ class PeriodicModelAverager(ModelAverager):
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
-                "by DistributedDataParall in the backward pass. Therefore, only "
+                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

@@ -0,0 +1,159 @@
# Copyright 2022 Cruise LLC
import warnings
from collections import OrderedDict
import logging

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils

logger = logging.getLogger(__name__)


class HierarchicalModelAverager:
    r"""
    A group of model averagers used for hierarchical model averaging (hierarchical SGD).
    Process groups of different sizes are organized in a hierarchy, and they average parameters
    concurrently with different periods after the warm-up stage.
    This is an extension of :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`,
    which supports `post-local SGD <https://arxiv.org/abs/1808.07217>`_ but essentially only
    a two-level hierarchy: the intra-machine level and the global level, where the intra-machine
    level is usually embedded in :meth:`~torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook`.
    Similarly, the process groups within this class do not include such an intra-machine process
    subgroup, which instead should be embedded by the post-local SGD communication hook.

    Args:
        period_group_size_dict: An ordered dict mapping model-averaging periods to process group sizes,
                                used for initializing process groups of different sizes in a hierarchy
                                to average parameters concurrently.
                                In particular, at each iteration there will be at most a single
                                process group that runs averaging -- the one whose period is the
                                largest period that divides the current step.
                                For example, if the dict has three keys: 2, 4, and 8,
                                then three process groups in total will be created to
                                average parameters every 2, 4, and 8 iterations, respectively.
                                At the 4th iteration, only the second process group runs
                                averaging, because the first process group should be a subset
                                of the second one, so there is no need to run the first
                                process group redundantly.
                                On the other hand, the third process group can only be triggered
                                every 8 iterations, so it is not triggered at the 4th iteration.
        warmup_steps (int): The number of warm-up steps. During this stage, model averaging is skipped.
        process_group (ProcessGroup, optional): The overall process group containing all the processes that run
                                                model averaging.
                                                If ``None``, the default process group, which is created
                                                by :func:`torch.distributed.init_process_group`, will be used.
                                                (default: ``None``)

    Example::
        >>> from collections import OrderedDict
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        >>>     PostLocalSGDState,
        >>>     post_localSGD_hook,
        >>> )
        >>> import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).to(rank)
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> # Assume that each machine has 4 GPUs, so each intra-machine subgroup has a size of 4.
        >>> subgroup, _ = dist.new_subgroups()
        >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD_hook)
        >>>
        >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
        >>> # the 16 processes every 16 iterations.
        >>> averager = hierarchicalSGD.HierarchicalModelAverager(
        >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging at two levels.
        >>> for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Average parameters after ``optimizer.step()``.
        >>>     # Thus, the inter-node communication only occurs periodically after ``warmup_steps``.
        >>>     averager.average_parameters(model.parameters())

    .. warning::
        The last group size in the dict must be the size of the provided ``process_group``,
        which indicates model averaging at the highest level of the hierarchy.
        If ``process_group`` is not provided, then the last group size should be equal to the world size.

    .. warning::
        `HierarchicalModelAverager` is experimental and subject to change.
    """

    def __init__(self, period_group_size_dict=None, warmup_steps=0, process_group=None):
        if not period_group_size_dict:
            raise ValueError("Arg ``period_group_size_dict`` must not be empty.")
        self._periods = list(period_group_size_dict.keys())
        if self._periods[0] <= 0:
            raise ValueError("The minimum period in arg ``period_group_size_dict`` must be a positive value.")
        elif self._periods[-1] == 1:
            warnings.warn(
                "When the maximum period in arg ``period_group_size_dict`` is 1, "
                "there is no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        overall_group: dist.ProcessGroup = (
            process_group if process_group is not None else dist.group.WORLD
        )
        overall_group_size = dist.get_world_size(group=overall_group)
        if list(period_group_size_dict.values())[-1] != overall_group_size:
            raise ValueError(
                "The last value in arg ``period_group_size_dict`` "
                "must be equal to the size of arg ``process_group``.")

        self.period_process_group_dict = OrderedDict()
        logger.info("Model averaging hierarchy:")
        for period, group_size in period_group_size_dict.items():
            logger.info(
                f"\tEach group that has {group_size} processes averages parameters every {period} iterations, "
                "if there is no higher-level averaging.")
            if group_size != overall_group_size:
                self.period_process_group_dict[period], _ = dist.new_subgroups(
                    group_size=group_size, group=overall_group)
            else:
                self.period_process_group_dict[period] = overall_group

        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        self.step = 0

    def _find_process_group(self):
        """
        Returns a tuple of (bool, process group): whether ``step`` can be divided by
        a period in the keys of ``period_process_group_dict``, and the associated process group if any.
        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        then the returned process group is the one corresponding to the largest such period,
        since this process group will be used for averaging parameters at this ``step``.
        """
        for period in reversed(self._periods):
            if self.step % period == 0:
                return (True, self.period_process_group_dict[period])
        return (False, None)

    def average_parameters(self, params):
        r"""
        Averages parameters if ``step`` is no less than ``warmup_steps``
        and it can be divided by a period in the keys of ``period_process_group_dict``,
        where ``step`` is increased by 1 at each iteration in the training loop.
        If ``step`` can be divided by multiple periods in the keys of ``period_process_group_dict``,
        only the largest such period is used, and the corresponding process group is used for averaging parameters.
        """
        if self.step >= self.warmup_steps:
            found, group = self._find_process_group()
            if found:
                utils.average_parameters(iter(params), group)
        self.step += 1

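To make the period-selection rule described in the docstring above concrete, here is a small standalone sketch (not part of the new file). The dict contents and the warm-up value below are illustrative assumptions; the selection logic mirrors `_find_process_group`: after warm-up, the group whose period is the largest divisor of the current step is the one that averages.

from collections import OrderedDict

# Illustrative hierarchy: size-2 groups every 2 steps, size-8 groups every 4 steps,
# and all 16 processes every 8 steps (made-up numbers).
period_group_size_dict = OrderedDict([(2, 2), (4, 8), (8, 16)])
periods = list(period_group_size_dict.keys())
warmup_steps = 10

def group_size_used_at(step):
    """Return the size of the process group that averages at this step, or None."""
    if step < warmup_steps:
        return None  # model averaging is skipped during the warm-up stage
    for period in reversed(periods):  # try the largest period first
        if step % period == 0:
            return period_group_size_dict[period]
    return None

# Steps 10..17 -> [2, None, 8, None, 2, None, 16, None]
print([group_size_used_at(step) for step in range(10, 18)])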
@@ -6,7 +6,7 @@ import random
import sys
import tempfile
import time
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
from contextlib import contextmanager, suppress
from datetime import timedelta
from functools import reduce

@@ -16,6 +16,7 @@ import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
import torch.nn.functional as F

@@ -1033,6 +1034,102 @@ class DistributedTest:
                        # No model averaging, so the parameters are not updated.
                        self.assertEqual(param.data, tensor)

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
        )
        @skip_if_lt_x_gpu(2)
        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(self):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            tensor = torch.ones_like(param.data) * rank
            expected_avg_tensor = (
                torch.ones_like(param.data) * sum(range(world_size)) / world_size
            )
            period = 4
            for warmup_steps in [12, 13, 14, 15]:
                averager = hierarchicalSGD.HierarchicalModelAverager(
                    # Run the global averaging at a period of 4,
                    # which is equivalent to the above periodic model averaging test case.
                    period_group_size_dict=OrderedDict([(period, world_size)]), warmup_steps=warmup_steps
                )

                averager = averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
                for step in range(0, 20):
                    # Reset the parameters at every step.
                    param.data = copy.deepcopy(tensor)
                    averager.average_parameters(model.parameters())
                    if step >= warmup_steps and (step - warmup_steps) % period == 0:
                        self.assertEqual(param.data, expected_avg_tensor)
                    else:
                        # No model averaging, so the parameters are not updated.
                        self.assertEqual(param.data, tensor)

        @sandcastle_skip_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices"
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_3_level_hierarchical_model_averager(self):
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(device_id)
            param = next(model.parameters())
            tensor = torch.ones_like(param.data) * rank
            # Set up such a hierarchical model averaging as follows:
            # after the first 10 warmup steps,
            # run model averaging every 2 steps within each subgroup of size 2,
            # run model averaging every 4 steps within each subgroup of size 4,
            # and run the global model averaging every 8 steps.
            # If there is a conflict in model averaging at a step, only run the highest-level model averaging.
            warmup_steps = 10
            subgroup_size1 = 2
            subgroup_avg_period1 = 2
            subgroup_size2 = 4
            subgroup_avg_period2 = 4
            global_avg_period = 8
            period_group_size_dict = OrderedDict(
                [(subgroup_avg_period1, subgroup_size1),
                 (subgroup_avg_period2, subgroup_size2),
                 (global_avg_period, world_size)])
            averager = hierarchicalSGD.HierarchicalModelAverager(
                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
            )
            expected_avg_tensor_within_subgroup1 = (
                torch.ones_like(param.data) * sum(range(subgroup_size1)) / subgroup_size1
            )
            expected_avg_tensor_within_subgroup2 = (
                torch.ones_like(param.data) * sum(range(subgroup_size2)) / subgroup_size2
            )
            expected_global_avg_tensor = (
                torch.ones_like(param.data) * sum(range(world_size)) / world_size
            )
            for step in range(0, 25):
                # Reset the parameters at every step.
                param.data = copy.deepcopy(tensor)
                averager.average_parameters(model.parameters())
                if step == 16 or step == 24:
                    # Run global model averaging when `step` can be divided by 8.
                    self.assertEqual(param.data, expected_global_avg_tensor)
                elif step == 12 or step == 20:
                    # Run model averaging within subgroup when `step` can be divided by 4 but not by 8.
                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup2)
                elif step == 10 or step == 14 or step == 18 or step == 22:
                    # Run model averaging within subgroup when `step` can be divided by 2 but not by 4 or 8.
                    self.assertEqual(param.data, expected_avg_tensor_within_subgroup1)
                else:
                    # No model averaging, so the parameters are not updated.
                    self.assertEqual(param.data, tensor)

        # NCCL Batch SEND RECV
        @skip_if_no_gpu
        @sandcastle_skip_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")