# Owner(s): ["oncall: distributed"]
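
# This file tests the generic `Join` context manager. All of the tests below
# exercise the usage pattern that `Join` expects, which looks roughly like the
# following sketch (`joinables` and `inputs` are placeholders, not names
# defined in this file):
#
#     with Join(joinables):
#         for input in inputs:
#             for joinable in joinables:
#                 joinable(input)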

import contextlib
import os
import sys
from typing import Any, Optional

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42


class AllReducerJoinHook(JoinHook):
    r"""
    Join hook for :class:`AllReducer`.

    Arguments:
        allreducer (AllReducer): the :class:`AllReducer` object using this
            hook.
        num_allreduces (int): the number of all-reduces to shadow per
            iteration.
        run_post_hook (bool): a flag enabling the post-hook logic.
    """
    def __init__(
        self,
        allreducer,
        num_allreduces,
        run_post_hook
    ):
        self.allreducer = allreducer
        self.num_allreduces = num_allreduces
        self.run_post_hook = run_post_hook

    def main_hook(self):
        r"""
        Shadows each all-reduce; the number of all-reduces is passed into the
        constructor as ``num_allreduces``.
        """
        device = self.allreducer.device
        for _ in range(self.num_allreduces):
            # Joined ranks contribute zeros, so the totals accumulated by
            # ranks that are still active only count the active ranks
            t = torch.zeros(1, device=device)
            dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from
        the last joiner to all other processes.
        """
        if not self.run_post_hook:
            return
        rank = dist.get_rank(self.allreducer.process_group)
        common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
        device = self.allreducer.device
        if rank == common_rank:
            self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device)
        dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)


class AllReducer(Joinable):
    r"""
    Example :class:`Joinable` that performs some number of all-reduces as its
    per-iteration collective communication.
    """
    def __init__(self, device, process_group):
        super(AllReducer, self).__init__()
        self.device = device
        self.process_group = process_group
        self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)

    def __call__(self, num_allreduces=1):
        r"""
        All-reduces a length-1 tensor of ones ``num_allreduces``-many times and
        returns the sum of the results.
        """
        Join.notify_join_context(self)
        device = self.device
        total = 0
        for _ in range(num_allreduces):
            t = torch.ones(1, device=device)
            dist.all_reduce(t)
            total += t.item()
        return total
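
    # Note: every all-reduce in ``__call__`` sums a 1 from each rank that has
    # not yet joined and a shadowed 0 from each joined rank, so a single call
    # adds the current number of active ranks to ``total``. The expected
    # totals in the tests below follow from this.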

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a join hook that shadows some number of all-reduces; by
        default, this number is 1.
        """
        num_allreduces = kwargs.get("num_allreduces", 1)
        run_post_hook = kwargs.get("run_post_hooks", False)
        return AllReducerJoinHook(
            self,
            num_allreduces,
            run_post_hook
        )

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self) -> Any:
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        # Ranks not under consideration contribute -1, so the max all-reduce
        # yields the highest rank among those with ``to_consider=True``
        common_rank = torch.tensor(
            [rank if to_consider else -1],
            device=self.device
        )
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        assert common_rank >= 0
        return common_rank


class TestJoin(MultiProcessTestCase):
    r"""Test cases for the generic join context."""
    def setUp(self):
        super(TestJoin, self).setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
            else torch.device("cpu")

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def tearDown(self):
        try:
            dist.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def dist_init(self, rank, world_size, backend=BACKEND):
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend,
            store=store,
            rank=rank,
            world_size=world_size
        )

    def construct_uneven_inputs(self, base, offset, device=None):
        r"""
        Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs.
        """
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)]
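
    # For example, with ``base=3`` and ``offset=1`` (the values used in the
    # tests below), ranks 0, 1, 2, and 3 would receive 3, 4, 5, and 6 inputs,
    # respectively.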

    def construct_even_inputs(self, base, device=None):
        r"""Returns even inputs: each rank gets ``base`` inputs."""
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base)]

    @property
    def base_num_inputs(self):
        r"""Base number of inputs to be used by all ranks."""
        return 3

    @property
    def offset(self):
        r"""Rank i gets i * ``offset`` additional inputs."""
        return 1

    def _test_join_base(
        self,
        uneven_inputs: bool,
        num_joinables: int,
        enable: bool,
        throw_on_early_termination: bool,
        num_allreduces: int,
        run_post_hooks: bool,
        expected_total: Optional[int] = None,
    ):
        r"""
        Skeleton for all :class:`Join` tests.

        Arguments:
            uneven_inputs (bool): ``True`` to use uneven inputs; ``False``
                otherwise.
            num_joinables (int): number of :class:`AllReducer` s to construct.
            enable (bool): ``True`` to enable the join context manager;
                ``False`` otherwise.
            throw_on_early_termination (bool): ``True`` to raise an exception
                upon detecting uneven inputs; ``False`` otherwise.
            num_allreduces (int): number of all-reduces to perform per input.
            run_post_hooks (bool): ``True`` to run post-hooks; ``False``
                otherwise.
            expected_total (Optional[int]): ``None`` to not check the expected
                all-reduce total; otherwise, the expected total; default is
                ``None``.
        """
        self.dist_init(self.rank, self.world_size)

        allreducers = [
            AllReducer(self.device, self.process_group)
            for _ in range(num_joinables)
        ]
        for allreducer in allreducers:
            self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT)

        inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \
            if uneven_inputs \
            else self.construct_even_inputs(self.base_num_inputs)
        allreduce_total = 0

        # Expect a `RuntimeError` if `throw_on_early_termination=True` since
        # rank 0 exhausts its inputs first
        expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \
            else "Detected at least one rank that exhausted inputs. " \
            "Throwing across all ranks."
        with self.assertRaisesRegex(
            RuntimeError,
            expected_msg
        ) if throw_on_early_termination else contextlib.suppress():
            with Join(
                allreducers,
                enable=enable,
                throw_on_early_termination=throw_on_early_termination,
                num_allreduces=num_allreduces,
                run_post_hooks=run_post_hooks
            ):
                for _ in inputs:
                    for allreducer in allreducers:
                        allreduce_total += allreducer(num_allreduces)

        if throw_on_early_termination:
            return

        # Check `expected_total` if it is not `None`
        if expected_total is not None:
            self.assertEqual(allreduce_total, expected_total)

        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
        if run_post_hooks:
            for allreducer in allreducers:
                self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_main_hooks(self):
        r"""Tests the main hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = False
        # Non-joined processes each all-reduce a 1, so this rank's all-reduce
        # total should equal the total number of inputs processed across all
        # ranks before this rank joined
        expected_total = self.world_size * self.base_num_inputs
        # Rank i runs for i additional iterations
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
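        # For example, with a world size of 4, a base of 3 inputs, and an
        # offset of 1, rank 2 expects 4 * 3 + (4 - 1) + (4 - 2) = 17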

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_post_hooks(self):
        r"""Tests the post-hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable(self):
        r"""
        Tests the main hooks and post-hooks of a single :class:`Joinable`
        together.

        This combines ``test_single_joinable_main_hooks()`` and
        ``test_single_joinable_post_hooks()`` into a single test to ensure that
        main hooks and post-hooks operate correctly together.
        """
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables(self):
        r"""
        Tests the main hooks and post-hooks of multiple :class:`Joinable` s
        together.

        This generalizes ``test_single_joinable()`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of ``num_joinables``
        expected_total *= num_joinables

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_disable(self):
        r"""Tests ``enable=False`` for a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinable_disable(self):
        r"""
        Tests ``enable=False`` for multiple :class:`Joinable` s.

        This generalizes ``test_single_joinable_disable`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs * num_joinables

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_single_joinable_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for a single
        :class:`Joinable`.
        """
        num_joinables = 1
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_multiple_joinables_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for multiple
        :class:`Joinable` s together.

        This generalizes ``test_single_joinable_throw`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(
        WORLD_SIZE, BACKEND
    )
    def test_join_kwargs(self):
        r"""
        Tests passing keyword arguments to the context manager.
        """
        num_joinables = 1
        num_allreduces = 2
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of ``num_allreduces``
        expected_total *= num_allreduces

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )


if __name__ == "__main__":
    run_tests()