From 9091096d6c42f9c032a0f5eaba0a3253900be643 Mon Sep 17 00:00:00 2001
From: Anant Gulati
Date: Wed, 5 Feb 2025 18:47:09 +0000
Subject: [PATCH] Refactoring Distributed test cases to be device agnostic [1/n] (#145222)

In this series of PRs we intend to refactor the distributed test cases to be
completely device agnostic. The changes take the following approaches:

- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to the CUDA stream/device APIs with
  torch.get_device_module(device) wherever it applies
- Skipping the manual setup steps required by MultiProcessTestCase by using
  DistributedTestBase (#138216) wherever applicable
- Replacing explicit references to a distributed backend (NCCL, HCCL, etc.)
  with get_default_backend_for_device (#140536)

This should result in a significant improvement in usability for all devices;
a minimal sketch of the resulting test pattern is included below.
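For reference, here is a minimal sketch of the device-agnostic test pattern
this refactor moves toward. It is illustrative only: ToyTest and _Toy are
made-up names for this sketch, and it assumes the internal test helpers used
in this patch (DistributedTestBase.create_pg, get_devtype,
torch.get_device_module) behave as they do in the diff below.

```python
# Illustrative sketch only; not part of this patch.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

# Resolve the accelerator under test once ("cuda", "hpu", "xpu", ..., or "cpu").
device_type = str(get_devtype())


class _Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class ToyTest(DistributedTestBase):
    @property
    def world_size(self) -> int:
        # Query the device module for the current accelerator instead of
        # hard-coding torch.cuda.device_count().
        return min(2, torch.get_device_module(device_type).device_count())

    def test_allreduce(self) -> None:
        # create_pg sets up the process group with the default backend for the
        # device rather than a hard-coded "nccl"/"gloo", and handles the store,
        # rank, and world size that MultiProcessTestCase made us wire up by hand.
        self.create_pg(device_type)
        model = _Toy().to(device_type)
        out = model(torch.ones(4, 8, device=device_type)).sum().detach()
        dist.all_reduce(out)
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```

The diff below applies the same pattern to ReplicateTest: the device is passed
into _test_compile, and create_pg replaces the manual init_process_group /
FileStore setup.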
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145222
Approved by: https://github.com/kwen2501
---
 .../test_replicate_with_compiler.py           | 63 ++++++-------------
 1 file changed, 20 insertions(+), 43 deletions(-)

diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py
index 91c3ecb47981..b945162a5109 100644
--- a/test/distributed/_composable/test_replicate_with_compiler.py
+++ b/test/distributed/_composable/test_replicate_with_compiler.py
@@ -2,10 +2,9 @@
 
 import contextlib
 import functools
-import os
 import unittest
 from copy import deepcopy
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -27,17 +26,20 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
+    DistributedTestBase,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
     sm_is_or_higher_than,
 )
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils.checkpoint import checkpoint
 
 
+device_type = str(get_devtype())
+
 DIM = 2000
 
 
@@ -72,7 +74,7 @@ def compiler_fn(no_inductor=False):
     return _compiler_fn
 
 
-class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
+class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
     """
     A version of MultiProcessTestCase that derives from the Inductor TestCase to
     handle isolation of the inductor cache dir.
@@ -80,46 +82,21 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
 
 
 class ReplicateTest(MultiProcessInductorTestCase):
-    # TODO: consider using all devices? The min(2, ...) here would limit the
-    # test to always run on 2 GPUs only.
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
+        return min(2, torch.get_device_module(device_type).device_count())
 
     def _test_compile(
         self,
         *,
-        use_gpu: bool,
        no_sync: bool,
         setup_func: Optional[Callable] = None,
         no_inductor: bool = False,
         no_compile_forward: bool = False,
         checkpoint: bool = False,
+        device: Union[str, torch.device],
     ):
-        backend = "nccl" if use_gpu else "gloo"
-        dist.init_process_group(
-            backend=backend,
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        if use_gpu:
-            torch.cuda.set_device(f"cuda:{self.rank}")
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-
+        self.create_pg(device)
         torch._dynamo.config.optimize_ddp = (
             "python_reducer_without_compiled_forward"
             if no_compile_forward
@@ -202,6 +179,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.assertEqual(
             tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
         )
+        dist.destroy_process_group()
 
     def test_compile_cpu(self):
         # Test the coalesced_op with CPU.
@@ -209,7 +187,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=False)
+        self._test_compile(no_sync=False, device="cpu")
 
     def test_compile_cpu_no_sync(self):
         # Test the coalesced_op with CPU.
@@ -217,7 +195,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=True)
+        self._test_compile(no_sync=True, device="cpu")
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
+        self._test_compile(no_sync=False, checkpoint=False, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu_ac(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
+        self._test_compile(no_sync=False, checkpoint=True, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_bf16(self):
         # Check device capability wrt bf16
-        device = torch.device("cuda", self.rank % torch.cuda.device_count())
-        if not sm_is_or_higher_than(device, 8, 0):
+        if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")
 
         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
                 None, ddp_default_hooks.bf16_compress_hook
             )
 
-        self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
+        self._test_compile(no_sync=False, setup_func=setup, device=device_type)
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
 
         # TODO: figure out why we need to disable Inductor to avoid test errors.
         self._test_compile(
-            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+            no_sync=False, setup_func=setup, no_inductor=True, device=device_type
         )
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_backward_only(self):
-        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
+        self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)
 
     def _test_bucketing(self, init_process_group=True, loop=1):
         if init_process_group:
@@ -397,7 +374,7 @@ class DDP_TP_Test(InductorTestCase):
         # Hmm, why a specific set_device call for rank 0?
         self.rank = 0
         self.world_size = 4
-        torch.cuda.set_device("cuda:0")
+        torch.get_device_module(device_type).set_device(device_type)
 
         store = FakeStore()
         dist.init_process_group(
@@ -419,7 +396,7 @@ class DDP_TP_Test(InductorTestCase):
         ref_model = Net()
         compiled_replicate_model = deepcopy(ref_model)
         mesh_2d = init_device_mesh(
-            "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+            device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]