diff --git a/test/distributed/test_c10d_object_collectives.py b/test/distributed/test_c10d_object_collectives.py
index dcd6de797e72..594564c45606 100644
--- a/test/distributed/test_c10d_object_collectives.py
+++ b/test/distributed/test_c10d_object_collectives.py
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import os
 import sys
 from functools import partial, wraps
 
@@ -12,8 +11,15 @@ if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
     sys.exit(0)
 
-from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
-from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
+from torch.testing._internal.common_distributed import DistributedTestBase, TEST_SKIPS
+from torch.testing._internal.common_utils import (
+    run_tests,
+    skipIfHpu,
+    TEST_CUDA,
+    TEST_HPU,
+    TEST_WITH_DEV_DBG_ASAN,
+)
 
 
 if TEST_WITH_DEV_DBG_ASAN:
@@ -23,7 +29,16 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
 
-BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
+if TEST_HPU:
+    DEVICE = "hpu"
+elif TEST_CUDA:
+    DEVICE = "cuda"
+else:
+    DEVICE = "cpu"
+
+device_module = torch.get_device_module(DEVICE)
+device_count = device_module.device_count()
+BACKEND = dist.get_default_backend_for_device(DEVICE)
 
 
 def with_comms(func=None):
@@ -34,59 +49,22 @@ def with_comms(func=None):
 
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
+        if DEVICE != "cpu" and device_count < self.world_size:
             sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
-        self.dist_init()
-        func(self)
-        self.destroy_comms()
+
+        kwargs["device"] = DEVICE
+        self.pg = self.create_pg(device=DEVICE)
+        try:
+            return func(self, *args, **kwargs)
+        finally:
+            torch.distributed.destroy_process_group()
 
     return wrapper
 
 
-class TestObjectCollectives(MultiProcessTestCase):
-    def setUp(self):
-        super().setUp()
-        os.environ["WORLD_SIZE"] = str(self.world_size)
-        os.environ["BACKEND"] = BACKEND
-        self._spawn_processes()
-
-    @property
-    def device(self):
-        return (
-            torch.device("cuda", self.rank % torch.cuda.device_count())
-            if BACKEND == dist.Backend.NCCL
-            else torch.device("cpu")
-        )
-
-    @property
-    def world_size(self):
-        if BACKEND == dist.Backend.NCCL:
-            return torch.cuda.device_count()
-        return super().world_size
-
-    @property
-    def process_group(self):
-        return dist.group.WORLD
-
-    def destroy_comms(self):
-        # Wait for all ranks to reach here before starting shutdown.
-        dist.barrier()
-        dist.destroy_process_group()
-
-    def dist_init(self):
-        dist.init_process_group(
-            backend=BACKEND,
-            world_size=self.world_size,
-            rank=self.rank,
-            init_method=f"file://{self.file_name}",
-        )
-
-        # set device for nccl pg for collectives
-        if BACKEND == "nccl":
-            torch.cuda.set_device(self.rank)
-
+class TestObjectCollectives(DistributedTestBase):
     @with_comms()
-    def test_all_gather_object(self):
+    def test_all_gather_object(self, device):
         output = [None] * dist.get_world_size()
 
         dist.all_gather_object(object_list=output, obj=self.rank)
@@ -94,7 +72,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(i, v, f"rank: {self.rank}")
 
     @with_comms()
-    def test_gather_object(self):
+    def test_gather_object(self, device):
         output = [None] * dist.get_world_size() if self.rank == 0 else None
 
         dist.gather_object(obj=self.rank, object_gather_list=output)
@@ -102,8 +80,9 @@ class TestObjectCollectives(MultiProcessTestCase):
             for i, v in enumerate(output):
                 self.assertEqual(i, v, f"rank: {self.rank}")
 
+    @skipIfHpu
     @with_comms()
-    def test_send_recv_object_list(self):
+    def test_send_recv_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         if self.rank == 0:
@@ -117,7 +96,7 @@ class TestObjectCollectives(MultiProcessTestCase):
             self.assertEqual(None, object_list[0])
 
     @with_comms()
-    def test_broadcast_object_list(self):
+    def test_broadcast_object_list(self, device):
         val = 99 if self.rank == 0 else None
         object_list = [val] * dist.get_world_size()
         # TODO test with broadcast_object_list's device argument
@@ -126,7 +105,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(99, object_list[0])
 
     @with_comms()
-    def test_scatter_object_list(self):
+    def test_scatter_object_list(self, device):
         input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
         output_list = [None]
         dist.scatter_object_list(
@@ -144,30 +123,34 @@ class TestObjectCollectives(MultiProcessTestCase):
         my_pg = dist.new_group(ranks, use_local_synchronization=True)
         return rank, ranks, my_pg
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_scatter_object(self):
+    def test_subpg_scatter_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
         self.assertEqual(rank, out_list[0])
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_all_gather_object(self):
+    def test_subpg_all_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks)
         dist.all_gather_object(out_list, rank, group=my_pg)
         self.assertEqual(ranks, out_list)
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_gather_object(self):
+    def test_subpg_gather_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None] * len(ranks) if rank == ranks[0] else None
         dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
         if rank == ranks[0]:
             self.assertEqual(ranks, out_list)
 
+    @skipIfHpu
     @with_comms()
-    def test_subpg_broadcast_object(self):
+    def test_subpg_broadcast_object(self, device):
         rank, ranks, my_pg = self.setup_sub_pg()
         out_list = [None]
         if rank == ranks[0]:
@@ -176,5 +159,7 @@ class TestObjectCollectives(MultiProcessTestCase):
         self.assertEqual(ranks[0], out_list[0])
 
 
+devices = ("cpu", "cuda", "hpu")
+instantiate_device_type_tests(TestObjectCollectives, globals(), only_for=devices)
 if __name__ == "__main__":
     run_tests()