Refactoring Distributed test cases to be device agnostic [1/n] (#145222)

In this series of PRs, we intend to refactor distributed test cases so that they are completely device agnostic.

These changes take the following approaches (illustrated in the sketch after this list):

- Allowing for multiple device types using instantiate_device_type_test
- Replacing explicit torch.cuda calls (e.g., streams) with torch.get_device_module(device) wherever applicable
- Skipping the setup steps required by MultiProcessTestCase by switching to DistributedTestBase (#138216) wherever applicable
- Replacing explicit references to distributed backends (NCCL, HCCL, etc.) with get_default_backend_for_device (#140536).

This should result in a significant usability improvement for all device types.
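
For context, here is a minimal sketch of the device-agnostic test pattern this series moves toward. It is not part of this PR's diff: the test class and test body are hypothetical, and it assumes that DistributedTestBase spawns worker processes in setUp and that create_pg(device) picks the backend via get_default_backend_for_device, matching how those helpers are used in the diff further down.

```python
# Hypothetical sketch of a device-agnostic distributed test (not in this PR).
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = str(get_devtype())  # e.g. "cuda", "hpu", "xpu", or "cpu"


class DeviceAgnosticAllReduceTest(DistributedTestBase):  # hypothetical class
    @property
    def world_size(self) -> int:
        # Query the active accelerator module instead of hard-coding torch.cuda
        return min(2, torch.get_device_module(device_type).device_count())

    def test_all_reduce(self):
        if device_type != "cpu":
            # Bind each rank to its own accelerator device
            torch.get_device_module(device_type).set_device(self.rank)
        # create_pg is expected to choose the backend via
        # get_default_backend_for_device (e.g. nccl for cuda, gloo for cpu),
        # replacing explicit backend="nccl"/"gloo" arguments.
        self.create_pg(device_type)
        t = torch.ones(4, device=device_type) * (self.rank + 1)
        dist.all_reduce(t)
        expected = torch.full_like(t, sum(range(1, self.world_size + 1)))
        self.assertEqual(t, expected)
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```

With this pattern the only device-specific knowledge lives in device_type, so the same test body can run on CUDA, HPU, XPU, or CPU builds without branching on the backend.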

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145222
Approved by: https://github.com/kwen2501
Author: Anant Gulati
Date: 2025-02-05 18:47:09 +00:00
Committed by: PyTorch MergeBot
Parent: 6f7fda3f49
Commit: 9091096d6c


@@ -2,10 +2,9 @@
import contextlib
import functools
import os
import unittest
from copy import deepcopy
from typing import Callable, Optional
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@@ -27,17 +26,20 @@ from torch.distributed.tensor.parallel import (
)
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
DistributedTestBase,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.utils.checkpoint import checkpoint
device_type = str(get_devtype())
DIM = 2000
@@ -72,7 +74,7 @@ def compiler_fn(no_inductor=False):
return _compiler_fn
class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
"""
A version of MultiProcessTestCase that derives from the Inductor TestCase
to handle isolation of the inductor cache dir.
@@ -80,46 +82,21 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
class ReplicateTest(MultiProcessInductorTestCase):
# TODO: consider using all devices? The min(2, ...) here would limit the
# test to always run on 2 GPUs only.
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
return min(2, torch.get_device_module(device_type).device_count())
def _test_compile(
self,
*,
use_gpu: bool,
no_sync: bool,
setup_func: Optional[Callable] = None,
no_inductor: bool = False,
no_compile_forward: bool = False,
checkpoint: bool = False,
device: Union[str, torch.device],
):
backend = "nccl" if use_gpu else "gloo"
dist.init_process_group(
backend=backend,
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
if use_gpu:
torch.cuda.set_device(f"cuda:{self.rank}")
device = torch.device("cuda")
else:
device = torch.device("cpu")
self.create_pg(device)
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
if no_compile_forward
@@ -202,6 +179,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
self.assertEqual(
tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
)
dist.destroy_process_group()
def test_compile_cpu(self):
# Test the coalesced_op with CPU.
@@ -209,7 +187,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
"fuse_ddp_with_coalesced_op",
"schedule_comm_wait",
]
self._test_compile(use_gpu=False, no_sync=False)
self._test_compile(no_sync=False, device="cpu")
def test_compile_cpu_no_sync(self):
# Test the coalesced_op with CPU.
@@ -217,7 +195,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
"fuse_ddp_with_coalesced_op",
"schedule_comm_wait",
]
self._test_compile(use_gpu=False, no_sync=True)
self._test_compile(no_sync=True, device="cpu")
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
reorder_for_locality=False, reorder_for_peak_memory=False
)
def test_compile_gpu(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
self._test_compile(no_sync=False, checkpoint=False, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
reorder_for_locality=False, reorder_for_peak_memory=False
)
def test_compile_gpu_ac(self):
self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
self._test_compile(no_sync=False, checkpoint=True, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_bf16(self):
# Check device capability wrt bf16
device = torch.device("cuda", self.rank % torch.cuda.device_count())
if not sm_is_or_higher_than(device, 8, 0):
if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
self.skipTest("bf16 requires sm >= 8.0")
def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
None, ddp_default_hooks.bf16_compress_hook
)
self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
self._test_compile(no_sync=False, setup_func=setup, device=device_type)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
# TODO: figure out why we need to disable Inductor to avoid test errors.
self._test_compile(
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
no_sync=False, setup_func=setup, no_inductor=True, device=device_type
)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(2)
def test_compile_backward_only(self):
self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)
def _test_bucketing(self, init_process_group=True, loop=1):
if init_process_group:
@@ -397,7 +374,7 @@ class DDP_TP_Test(InductorTestCase):
# Hmm, why a specific set_device call for rank 0?
self.rank = 0
self.world_size = 4
torch.cuda.set_device("cuda:0")
torch.get_device_module(device_type).set_device(device_type)
store = FakeStore()
dist.init_process_group(
@@ -419,7 +396,7 @@ class DDP_TP_Test(InductorTestCase):
ref_model = Net()
compiled_replicate_model = deepcopy(ref_model)
mesh_2d = init_device_mesh(
"cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]