# Owner(s): ["oncall: distributed"]

import contextlib
import os
import sys
from typing import Any, Optional

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42


class AllReducerJoinHook(JoinHook):
    r"""
    Join hook for :class:`AllReducer`.

    Arguments:
        allreducer (AllReducer): the :class:`AllReducer` object using this
            hook.
        num_allreduces (int): the number of all-reduces to shadow per
            iteration.
        run_post_hook (bool): a flag enabling the post-hook logic.
    """
    def __init__(self, allreducer, num_allreduces, run_post_hook):
        self.allreducer = allreducer
        self.num_allreduces = num_allreduces
        self.run_post_hook = run_post_hook

    def main_hook(self):
        r"""
        Shadows each all-reduce; the number of all-reduces is passed into the
        constructor as ``num_allreduces``.
        """
        device = self.allreducer.device
        for _ in range(self.num_allreduces):
            t = torch.zeros(1, device=device)
            dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Broadcasts a tensor containing the magic constant ``AFTER_CONSTANT``
        from the last joiner to all other processes.
        """
        if not self.run_post_hook:
            return
        rank = dist.get_rank(self.allreducer.process_group)
        common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
        device = self.allreducer.device
        if rank == common_rank:
            self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device)
        dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)


class AllReducer(Joinable):
    r"""
    Example :class:`Joinable` that performs some number of all-reduces as its
    per-iteration collective communication.
    """
    def __init__(self, device, process_group):
        super().__init__()
        self.device = device
        self.process_group = process_group
        self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)

    def __call__(self, num_allreduces=1):
        r"""
        All-reduces a single-element tensor of ones ``num_allreduces`` times
        and returns the accumulated total.
        """
        Join.notify_join_context(self)
        device = self.device
        total = 0
        for _ in range(num_allreduces):
            t = torch.ones(1, device=device)
            dist.all_reduce(t)
            total += t.item()
        return total

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a join hook that shadows some number of all-reduces; by
        default, this number is 1.
        """
        num_allreduces = kwargs.get("num_allreduces", 1)
        run_post_hook = kwargs.get("run_post_hooks", False)
        return AllReducerJoinHook(self, num_allreduces, run_post_hook)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self) -> Any:
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
""" common_rank = torch.tensor( [rank if to_consider else -1], device=self.device ) dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group) common_rank = common_rank.item() assert common_rank >= 0 return common_rank class TestJoin(MultiProcessTestCase): r"""Test cases for the generic join context.""" def setUp(self): super(TestJoin, self).setUp() os.environ["WORLD_SIZE"] = str(self.world_size) os.environ["BACKEND"] = BACKEND self._spawn_processes() @property def device(self): return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \ else torch.device("cpu") @property def world_size(self): return WORLD_SIZE @property def process_group(self): return dist.group.WORLD def tearDown(self): try: dist.destroy_process_group() except AssertionError: pass try: os.remove(self.file_name) except OSError: pass def dist_init(self, rank, world_size, backend=BACKEND): store = dist.FileStore(self.file_name, world_size) return dist.init_process_group( backend=backend, store=store, rank=rank, world_size=world_size ) def construct_uneven_inputs(self, base, offset, device=None): r""" Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs. """ if device is None: device = self.device return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)] def construct_even_inputs(self, base, device=None): r"""Returns even inputs: each rank gets ``base`` inputs.""" if device is None: device = self.device return [torch.zeros(1, device=device) for _ in range(base)] @property def base_num_inputs(self): r"""Base number of inputs to be used by all ranks.""" return 3 @property def offset(self): r"""Rank i gets i * ``offset`` additional inputs.""" return 1 def _test_join_base( self, uneven_inputs: bool, num_joinables: int, enable: bool, throw_on_early_termination: bool, num_allreduces: int, run_post_hooks: bool, expected_total: Optional[int] = None, ): r""" Skeleton for all :class:`Join` tests. Arguments: uneven_inputs (bool): ``True`` to use uneven inputs; ``False`` otherwise. num_joinables (int): number of :class:`AllReducer` s to construct. enable (bool): ``True`` to enable the join context manager; ``False`` otherwise. throw_on_early_termination (bool): ``True`` to raise an exception upon detecting uneven inputs; ``False`` otherwise. num_allreduces (int): number of all-reduces to perform per input. run_post_hooks (bool): ``True`` to run post-hooks; ``False`` otherwise. expected_total (Optional[int]): ``None`` to not check the expected all-reduce total; otherwise, the expected total; default is ``None``. """ self.dist_init(self.rank, self.world_size) allreducers = [ AllReducer(self.device, self.process_group) for _ in range(num_joinables) ] for allreducer in allreducers: self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT) inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \ if uneven_inputs \ else self.construct_even_inputs(self.base_num_inputs) allreduce_total = 0 # Expect a `RuntimeError` if `throw_on_early_termination=True` # Rank 0 exhausts its inputs first expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \ else "Detected at least one rank that exhausted inputs. " \ "Throwing across all ranks." 
        with self.assertRaisesRegex(
            RuntimeError,
            expected_msg
        ) if throw_on_early_termination else contextlib.suppress():
            with Join(
                allreducers,
                enable=enable,
                throw_on_early_termination=throw_on_early_termination,
                num_allreduces=num_allreduces,
                run_post_hooks=run_post_hooks
            ):
                for _ in inputs:
                    for allreducer in allreducers:
                        allreduce_total += allreducer(num_allreduces)

        if throw_on_early_termination:
            return

        # Check `expected_total` if it is not `None`
        if expected_total is not None:
            self.assertEqual(allreduce_total, expected_total)
        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
        if run_post_hooks:
            for allreducer in allreducers:
                self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_main_hooks(self):
        r"""Tests the main hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = False
        # Non-joined processes all-reduce a 1, so this rank's all-reduce total
        # should be precisely equal to the total number of inputs processed
        # before it joined
        expected_total = self.world_size * self.base_num_inputs
        # Rank i runs for i additional iterations
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_post_hooks(self):
        r"""Tests the post-hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True  # this test exercises the post-hook path

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable(self):
        r"""
        Tests the main hooks and post-hooks of a single :class:`Joinable`
        together.

        This combines ``test_single_joinable_main_hooks()`` and
        ``test_single_joinable_post_hooks()`` into a single test to ensure
        that main hooks and post-hooks operate correctly together.
        """
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_multiple_joinables(self):
        r"""
        Tests the main hooks and post-hooks of multiple :class:`Joinable` s
        together.

        This generalizes ``test_single_joinable()`` to multiple
        :class:`Joinable` s.
""" num_joinables = 3 num_allreduces = 1 run_post_hooks = True expected_total = self.world_size * self.base_num_inputs for num_joined in range(1, self.rank + 1): expected_total += (self.world_size - num_joined) * self.offset # The expected total is now multiplied by a factor of `NUM_JOINABLES` expected_total *= num_joinables self._test_join_base( uneven_inputs=True, num_joinables=num_joinables, enable=True, throw_on_early_termination=False, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=expected_total ) @require_n_gpus_for_nccl_backend( WORLD_SIZE, BACKEND ) def test_single_joinable_disable(self): r"""Tests ``enable=False`` for a single :class:`Joinable`.""" num_joinables = 1 num_allreduces = 1 uneven_inputs = False enable = False run_post_hooks = False expected_total = self.world_size * self.base_num_inputs self._test_join_base( uneven_inputs=uneven_inputs, num_joinables=num_joinables, enable=enable, throw_on_early_termination=False, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=expected_total ) @require_n_gpus_for_nccl_backend( WORLD_SIZE, BACKEND ) def test_multiple_joinable_disable(self): r""" Tests ``enable=False`` for multiple :class:`Joinable` s. This generalizes ``test_single_joinable_disable`` to multiple :class:`Joinable` s. """ num_joinables = 3 num_allreduces = 1 uneven_inputs = False enable = False run_post_hooks = False expected_total = self.world_size * self.base_num_inputs * num_joinables self._test_join_base( uneven_inputs=uneven_inputs, num_joinables=num_joinables, enable=enable, throw_on_early_termination=False, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=expected_total ) @require_n_gpus_for_nccl_backend( WORLD_SIZE, BACKEND ) def test_single_joinable_throw(self): r""" Tests ``throw_on_early_termination=True`` for a single :class:`Joinable`. """ num_joinables = 1 num_allreduces = 1 throw_on_early_termination = True run_post_hooks = False self._test_join_base( uneven_inputs=True, num_joinables=num_joinables, enable=True, throw_on_early_termination=throw_on_early_termination, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=None ) @require_n_gpus_for_nccl_backend( WORLD_SIZE, BACKEND ) def test_multiple_joinables_throw(self): r""" Tests ``throw_on_early_termination=True`` for multiple :class:`Joinable` s together. This generalizes ``test_single_joinable_throw`` to multiple :class:`Joinable` s. """ num_joinables = 3 num_allreduces = 1 throw_on_early_termination = True run_post_hooks = False self._test_join_base( uneven_inputs=True, num_joinables=num_joinables, enable=True, throw_on_early_termination=throw_on_early_termination, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=None ) @require_n_gpus_for_nccl_backend( WORLD_SIZE, BACKEND ) def test_join_kwargs(self): r""" Tests passing keyword arguments to the context manager. """ num_joinables = 1 num_allreduces = 2 run_post_hooks = False expected_total = self.world_size * self.base_num_inputs for num_joined in range(1, self.rank + 1): expected_total += (self.world_size - num_joined) * self.offset # The expected total is now multiplied by a factor of `NUM_ALLREDUCES` expected_total *= num_allreduces self._test_join_base( uneven_inputs=True, num_joinables=num_joinables, enable=True, throw_on_early_termination=False, num_allreduces=num_allreduces, run_post_hooks=run_post_hooks, expected_total=expected_total ) if __name__ == "__main__": run_tests()