mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactoring Distributed test cases to be device agnostic [1/n] (#145222)
In this series of PRs we intend to refactor the distributed test cases so that they are completely device agnostic. The changes take the following approaches:
- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to the cuda stream with torch.get_device_module(device) wherever it applies
- Skipping the setup steps required by MultiProcessTestCase by using DistributedTestBase (#138216) wherever applicable
- Replacing explicit calls to distributed backends (NCCL, HCCL, etc.) with get_default_backend_for_device (#140536)

This should result in a significant improvement in usability for all devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145222
Approved by: https://github.com/kwen2501
committed by PyTorch MergeBot
parent 6f7fda3f49
commit 9091096d6c
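The patterns named in the commit message can be summarized with a short, hypothetical sketch (this is not code from the PR; only torch.get_device_module and get_default_backend_for_device are the helpers the message refers to, and exposing the latter at the torch.distributed level is assumed based on the linked #140536):

```python
# Hypothetical, self-contained sketch of the device-agnostic idioms described
# in the commit message; example_allreduce is illustrative only.
import torch
import torch.distributed as dist


def example_allreduce(device_type: str, rank: int, world_size: int) -> torch.Tensor:
    # Derive the backend from the device type instead of hard-coding "nccl"/"gloo".
    backend = dist.get_default_backend_for_device(device_type)
    dist.init_process_group(
        backend=backend, rank=rank, world_size=world_size, init_method="env://"
    )
    if device_type != "cpu":
        # Generic replacement for torch.cuda.set_device(rank).
        torch.get_device_module(device_type).set_device(rank)
    t = torch.ones(4, device=device_type)
    dist.all_reduce(t)
    dist.destroy_process_group()
    return t
```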
@@ -2,10 +2,9 @@

 import contextlib
 import functools
-import os
 import unittest
 from copy import deepcopy
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import torch
 import torch.distributed as dist
@@ -27,17 +26,20 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
+    DistributedTestBase,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
     sm_is_or_higher_than,
 )
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils.checkpoint import checkpoint


+device_type = str(get_devtype())
+
 DIM = 2000
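For reference, the new module-level device_type comes from get_devtype() in common_fsdp, which resolves the accelerator type available on the current machine; a minimal usage sketch (the exact fallback behavior of get_devtype is an assumption here):

```python
# Minimal sketch: consume the module-level device_type instead of hard-coding "cuda".
import torch
from torch.testing._internal.common_fsdp import get_devtype

device_type = str(get_devtype())  # e.g. "cuda" on a CUDA machine
device_module = torch.get_device_module(device_type)
print(f"{device_type}: {device_module.device_count()} device(s) visible")
```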
@@ -72,7 +74,7 @@ def compiler_fn(no_inductor=False):
     return _compiler_fn


-class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
+class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
    """
    A version of MultiProcessTestCase that derives from the Inductor TestCase
    to handle isolation of the inductor cache dir.
@@ -80,46 +82,21 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):


 class ReplicateTest(MultiProcessInductorTestCase):
     # TODO: consider using all devices? The min(2, ...) here would limit the
     # test to always run on 2 GPUs only.
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
+        return min(2, torch.get_device_module(device_type).device_count())

     def _test_compile(
         self,
         *,
-        use_gpu: bool,
         no_sync: bool,
         setup_func: Optional[Callable] = None,
         no_inductor: bool = False,
         no_compile_forward: bool = False,
         checkpoint: bool = False,
+        device: Union[str, torch.device],
     ):
-        backend = "nccl" if use_gpu else "gloo"
-        dist.init_process_group(
-            backend=backend,
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        if use_gpu:
-            torch.cuda.set_device(f"cuda:{self.rank}")
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-
+        self.create_pg(device)
         torch._dynamo.config.optimize_ddp = (
             "python_reducer_without_compiled_forward"
             if no_compile_forward
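DistributedTestBase.create_pg(device) absorbs the process-group setup that the removed lines spelled out by hand. A rough sketch of the equivalent logic, with backend_for_device as a hypothetical stand-in for the backend-selection helper (the real implementation lives in torch.testing._internal.common_distributed and may differ):

```python
# Rough, simplified equivalent of the boilerplate removed above; not the actual
# DistributedTestBase implementation.
import torch
import torch.distributed as dist


def backend_for_device(device) -> str:
    # Hypothetical helper mirroring get_default_backend_for_device from the
    # commit message.
    return {"cuda": "nccl", "hpu": "hccl"}.get(torch.device(device).type, "gloo")


def create_pg_sketch(device, rank: int, world_size: int, file_name: str) -> None:
    dist.init_process_group(
        backend=backend_for_device(device),
        rank=rank,
        world_size=world_size,
        store=dist.FileStore(file_name, world_size),
    )
    dev = torch.device(device)
    if dev.type != "cpu":
        # Pin this rank to its device through the generic device module.
        torch.get_device_module(dev.type).set_device(rank)
```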
@@ -202,6 +179,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.assertEqual(
             tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
         )
+        dist.destroy_process_group()

     def test_compile_cpu(self):
         # Test the coalesced_op with CPU.
@@ -209,7 +187,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=False)
+        self._test_compile(no_sync=False, device="cpu")

     def test_compile_cpu_no_sync(self):
         # Test the coalesced_op with CPU.
@@ -217,7 +195,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=True)
+        self._test_compile(no_sync=True, device="cpu")

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
+        self._test_compile(no_sync=False, checkpoint=False, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu_ac(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
+        self._test_compile(no_sync=False, checkpoint=True, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_bf16(self):
         # Check device capability wrt bf16
-        device = torch.device("cuda", self.rank % torch.cuda.device_count())
-        if not sm_is_or_higher_than(device, 8, 0):
+        if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")

         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
                 None, ddp_default_hooks.bf16_compress_hook
             )

-        self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
+        self._test_compile(no_sync=False, setup_func=setup, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ class ReplicateTest(MultiProcessInductorTestCase):

         # TODO: figure out why we need to disable Inductor to avoid test errors.
         self._test_compile(
-            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+            no_sync=False, setup_func=setup, no_inductor=True, device=device_type
         )

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_backward_only(self):
-        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
+        self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)

     def _test_bucketing(self, init_process_group=True, loop=1):
         if init_process_group:
@@ -397,7 +374,7 @@ class DDP_TP_Test(InductorTestCase):
         # Hmm, why a specific set_device call for rank 0?
         self.rank = 0
         self.world_size = 4
-        torch.cuda.set_device("cuda:0")
+        torch.get_device_module(device_type).set_device(device_type)

         store = FakeStore()
         dist.init_process_group(
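The replacement set_device call goes through the generic device module, so the same line works for cuda, xpu, hpu, and other accelerators; a tiny standalone illustration (device_type is hard-coded here only for the example):

```python
import torch

device_type = "cuda"  # placeholder; the test derives this from str(get_devtype())
if device_type != "cpu":
    # Backend-neutral equivalent of torch.cuda.set_device("cuda:0").
    torch.get_device_module(device_type).set_device(0)
```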
@@ -419,7 +396,7 @@ class DDP_TP_Test(InductorTestCase):
         ref_model = Net()
         compiled_replicate_model = deepcopy(ref_model)
         mesh_2d = init_device_mesh(
-            "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+            device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]
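For completeness, the device-agnostic mesh construction above can be exercised standalone roughly like this (assumes a 4-rank process group is already initialized, as in the test):

```python
# Illustrative only: 2x2 ("dp", "tp") mesh built from the generic device_type
# instead of a hard-coded "cuda".
from torch.distributed.device_mesh import init_device_mesh

device_type = "cuda"  # placeholder; the test uses str(get_devtype())
mesh_2d = init_device_mesh(device_type, (2, 2), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]
```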