# Owner(s): ["oncall: distributed"] import os import sys import tempfile import threading from functools import partial, wraps import torch import torch.distributed as dist import torch.distributed._hooks as dhooks if not dist.is_available(): print("torch.distributed not available, skipping tests", file=sys.stderr) sys.exit(0) from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import run_tests, TestCase class PgHooks(MultiProcessTestCase): @property def world_size(self) -> int: return 4 def setUp(self) -> None: super().setUp() self._spawn_processes() def tearDown(self): super().tearDown() try: os.remove(self.file_name) except OSError: pass def test_pg_hook(self): pgs = [] def pg_hook(pg, pg_name): pgs.append((pg, pg_name)) dhooks.register_process_group_hook(pg_hook) dist.init_process_group( backend="gloo", rank=self.rank, world_size=self.world_size, store=dist.FileStore(self.file_name, self.world_size), ) self.assertEqual(len(pgs), 1) self.assertEqual(pgs[0][0], dist.group.WORLD) # create two partial world PGs pg0 = dist.new_group(ranks=[0, 1]) pg1 = dist.new_group(ranks=[2, 3]) # Each rank only observe two PGs being created: the default PG and one covering its ranks # We don't emit events for PG creation if the current rank doesn't belong to it. # For example, say you're rank 1, you'll get an event for pg0 but not pg1 even though the API contact # dictates you need to call new_group for both. self.assertEqual(len(pgs), 2) self.assertEqual(pgs[1][0], pg0 if self.rank < 2 else pg1) def with_comms(func=None): if func is None: return partial( with_comms, ) @wraps(func) def wrapper(self, *args, **kwargs): self.init_comms() func(self, *args, **kwargs) self.destroy_comms() return wrapper class CollectiveHooks: @property def world_size(self) -> int: return 4 def _collective_hooks(self): # it's ok to access them directly since there's a single bg thread poking at them. 
        starts = []
        ends = []
        cv = threading.Condition()

        def coll_start(status):
            starts.append(status)
            print(f"col_start {len(starts)} rank{self.rank}")

        def coll_end(status):
            ends.append(status)
            print(f"col_end {len(ends)} rank{self.rank}")
            if len(ends) == 2:
                with cv:
                    cv.notify()

        dhooks.register_collective_start_hook(coll_start)
        dhooks.register_collective_end_hook(coll_end)

        tensor = torch.ones([2, 3]).to(self.device) * self.rank
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]

        dist.all_gather(tensor_list, tensor)
        tensor2 = torch.ones([2, 3]).to(self.device) * self.rank
        dist.all_reduce(tensor2)
        with cv:
            cv.wait(1)

        default_pg_name = dist.group.WORLD.group_name
        self.assertEqual(2, len(starts))
        self.assertEqual(2, len(ends))

        def check_op(idx, coll_name):
            self.assertEqual(default_pg_name, starts[idx].pg_name)
            self.assertEqual(self.backend_name, starts[idx].backend)
            self.assertGreaterEqual(starts[idx].sequence_number, 0)
            self.assertGreaterEqual(starts[idx].timestamp, 0)
            self.assertEqual(coll_name, starts[idx].operation)

            self.assertEqual(default_pg_name, ends[idx].pg_name)
            self.assertEqual(self.backend_name, ends[idx].backend)
            self.assertEqual(starts[idx].sequence_number, ends[idx].sequence_number)
            self.assertLessEqual(starts[idx].timestamp, ends[idx].timestamp)
            self.assertEqual(coll_name, ends[idx].operation)

        check_op(0, "ALLGATHER")
        check_op(1, "ALLREDUCE")


class GlooHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "gloo"

    @property
    def device(self):
        return "cpu"

    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class NcclHooks(MultiProcessTestCase, CollectiveHooks):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def init_comms(self):
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def destroy_comms(self):
        dist.destroy_process_group()

    @property
    def backend_name(self):
        return "nccl"

    @property
    def device(self):
        return f"cuda:{self.rank}"

    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_collective_hooks(self):
        self._collective_hooks()


class SingleRankTests(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.rank = 0
        self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group(
            backend="gloo",
            rank=0,
            world_size=1,
            store=dist.FileStore(self.file_name, 1),
        )

    def tearDown(self) -> None:
        dist.destroy_process_group()

    def test_queue_overflow(self) -> None:
        cv_done_colls = threading.Condition()
        cv_done_cb = threading.Condition()
        colls_done = False
        starts = []
        status_with_dropped = None

        def coll_start(status: dhooks.CollectiveStatus):
            starts.append(status)
            with cv_done_colls:
                while not colls_done:
                    cv_done_colls.wait()
            if status.drop_count > 0:
                nonlocal status_with_dropped
                status_with_dropped = status
                with cv_done_cb:
                    cv_done_cb.notify()

        dhooks.register_collective_start_hook(coll_start)

        # Native limit is 512.
        for i in range(600):
            dist.all_reduce(torch.ones([2, 3]))
        colls_done = True
        with cv_done_colls:
            cv_done_colls.notify()
        with cv_done_cb:
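            # Give the callback thread up to 10s to observe a status with a non-zero drop_count.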
            cv_done_cb.wait(10)

        self.assertTrue(status_with_dropped is not None)
        self.assertTrue(status_with_dropped.drop_count > 0)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()