Refactoring Distributed test cases to be device agnostic [1/n] (#145222)

In this series of PRs we intend to refactor the distributed test cases so that they are completely device agnostic.

These changes take the following approaches:

- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to CUDA streams with their torch.get_device_module(device) equivalents wherever applicable
- Skipping the setup steps required by MultiProcessTestCase by using DistributedTestBase (#138216) wherever applicable
- Replacing explicit references to distributed backends (NCCL, HCCL, etc.) with get_default_backend_for_device (#140536)

This should result in a significant improvement in usability across all device types.
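
As a rough illustration of the target pattern, here is a minimal sketch of a device-agnostic test built on these helpers. The test class name and the all-reduce body are hypothetical; DistributedTestBase, create_pg, get_devtype, torch.get_device_module, and run_tests are the utilities the diff below actually uses.

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

# Resolve the accelerator under test once instead of hard-coding "cuda".
device_type = str(get_devtype())


class ExampleDeviceAgnosticTest(DistributedTestBase):  # hypothetical test class
    @property
    def world_size(self) -> int:
        # torch.get_device_module(device_type) resolves to torch.cuda, torch.xpu, ...
        return min(2, torch.get_device_module(device_type).device_count())

    def test_all_reduce(self):
        # create_pg picks the default backend for the device
        # (get_default_backend_for_device), so no "nccl"/"gloo" string is hard-coded.
        self.create_pg(device_type)
        if device_type != "cpu":
            # Bind this rank to one accelerator; the device-agnostic analogue of
            # torch.cuda.set_device(self.rank).
            torch.get_device_module(device_type).set_device(self.rank)
        t = torch.ones(4, device=device_type)
        dist.all_reduce(t)
        self.assertEqual(t, torch.full_like(t, self.world_size))
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
```

Because create_pg resolves the backend through get_default_backend_for_device, a test written this way can run with NCCL on CUDA, HCCL where available, or gloo on CPU-only hosts without per-backend branching.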

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145222
Approved by: https://github.com/kwen2501
Author: Anant Gulati
Date: 2025-02-05 18:47:09 +00:00
Committed by: PyTorch MergeBot
Parent: 6f7fda3f49
Commit: 9091096d6c


@@ -2,10 +2,9 @@

 import contextlib
 import functools
-import os
 import unittest
 from copy import deepcopy
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import torch
 import torch.distributed as dist
@@ -27,17 +26,20 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
+    DistributedTestBase,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
     sm_is_or_higher_than,
 )
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils.checkpoint import checkpoint

+device_type = str(get_devtype())
+
 DIM = 2000
@@ -72,7 +74,7 @@ def compiler_fn(no_inductor=False):
     return _compiler_fn


-class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
+class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
    """
    A version of MultiProcessTestCase that derives from the Inductor TestCase
    to handle isolation of the inductor cache dir.
@@ -80,46 +82,21 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
 class ReplicateTest(MultiProcessInductorTestCase):
+    # TODO: consider using all devices? The min(2, ...) here would limit the
+    # test to always run on 2 GPUs only.
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(2, torch.get_device_module(device_type).device_count())

-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
     def _test_compile(
         self,
         *,
-        use_gpu: bool,
         no_sync: bool,
         setup_func: Optional[Callable] = None,
         no_inductor: bool = False,
         no_compile_forward: bool = False,
         checkpoint: bool = False,
+        device: Union[str, torch.device],
     ):
-        backend = "nccl" if use_gpu else "gloo"
-        dist.init_process_group(
-            backend=backend,
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        if use_gpu:
-            torch.cuda.set_device(f"cuda:{self.rank}")
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-
+        self.create_pg(device)
         torch._dynamo.config.optimize_ddp = (
             "python_reducer_without_compiled_forward"
             if no_compile_forward
@@ -202,6 +179,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.assertEqual(
             tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
         )
+        dist.destroy_process_group()

     def test_compile_cpu(self):
         # Test the coalesced_op with CPU.
@@ -209,7 +187,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=False)
+        self._test_compile(no_sync=False, device="cpu")

     def test_compile_cpu_no_sync(self):
         # Test the coalesced_op with CPU.
@@ -217,7 +195,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=True)
+        self._test_compile(no_sync=True, device="cpu")

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
+        self._test_compile(no_sync=False, checkpoint=False, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu_ac(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
+        self._test_compile(no_sync=False, checkpoint=True, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_bf16(self):
         # Check device capability wrt bf16
-        device = torch.device("cuda", self.rank % torch.cuda.device_count())
-        if not sm_is_or_higher_than(device, 8, 0):
+        if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")

         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
                 None, ddp_default_hooks.bf16_compress_hook
             )

-        self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
+        self._test_compile(no_sync=False, setup_func=setup, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
         # TODO: figure out why we need to disable Inductor to avoid test errors.
         self._test_compile(
-            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+            no_sync=False, setup_func=setup, no_inductor=True, device=device_type
         )

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_backward_only(self):
-        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
+        self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)

     def _test_bucketing(self, init_process_group=True, loop=1):
         if init_process_group:
@@ -397,7 +374,7 @@ class DDP_TP_Test(InductorTestCase):
+        # Hmm, why a specific set_device call for rank 0?
         self.rank = 0
         self.world_size = 4
-        torch.cuda.set_device("cuda:0")
+        torch.get_device_module(device_type).set_device(device_type)

         store = FakeStore()
         dist.init_process_group(
@@ -419,7 +396,7 @@
         ref_model = Net()
         compiled_replicate_model = deepcopy(ref_model)
         mesh_2d = init_device_mesh(
-            "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+            device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]