Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Refactoring Distributed test cases to be device agnostic [1/n] (#145222)
In this series of PRs we intend to refactor distributed test cases to be completely device agnostic. The changes take the following approaches:
- Allowing for multiple device types using instantiate_device_type_test
- Replacing calls to the CUDA stream with torch.get_device_module(device) wherever it applies
- Skipping the setup steps required by MultiProcessTestCase by using DistributedTestBase (#138216) wherever applicable
- Replacing explicit references to a distributed backend (NCCL, HCCL, etc.) with get_default_backend_for_device (#140536)

This should result in a significant improvement in usability for all devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145222
Approved by: https://github.com/kwen2501
committed by PyTorch MergeBot
parent 6f7fda3f49
commit 9091096d6c
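For orientation, the sketch below shows the device-agnostic test pattern this series moves toward, pieced together from the helpers named in the commit message and the diff that follows. It is an illustrative example, not code from this commit: the test class, test method, and assertions are hypothetical, and the internal helper behavior (DistributedTestBase.create_pg, get_devtype) may differ between PyTorch versions.

# Hypothetical, minimal device-agnostic distributed test (illustration only).
import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = str(get_devtype())  # e.g. "cuda", "xpu", "hpu", or "cpu"


class DeviceAgnosticAllReduceTest(DistributedTestBase):  # hypothetical test class
    @property
    def world_size(self) -> int:
        # Query the accelerator module generically instead of torch.cuda directly.
        return min(2, torch.get_device_module(device_type).device_count())

    def test_all_reduce(self):
        # create_pg() initializes the process group with the default backend for
        # the device (per the commit message it builds on
        # get_default_backend_for_device, e.g. nccl for cuda, gloo for cpu),
        # replacing the manual init_process_group/FileStore boilerplate used with
        # MultiProcessTestCase. We assume it also binds this rank to its device,
        # as implied by the removal of the explicit set_device call below.
        self.create_pg(device_type)
        t = torch.ones(2, device=device_type)
        dist.all_reduce(t)
        # Each rank contributed a tensor of ones, so the sum is 2 * world_size.
        self.assertEqual(t.sum().item(), 2.0 * self.world_size)
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()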
@@ -2,10 +2,9 @@
 import contextlib
 import functools
-import os
 import unittest
 from copy import deepcopy
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 import torch
 import torch.distributed as dist
@@ -27,17 +26,20 @@ from torch.distributed.tensor.parallel import (
 )
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
+    DistributedTestBase,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
     sm_is_or_higher_than,
 )
+from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import run_tests, skipIfRocm
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.utils.checkpoint import checkpoint


+device_type = str(get_devtype())

 DIM = 2000

@@ -72,7 +74,7 @@ def compiler_fn(no_inductor=False):
     return _compiler_fn


-class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):
+class MultiProcessInductorTestCase(DistributedTestBase, InductorTestCase):
     """
     A version of MultiProcessTestCase that derives from the Inductor TestCase
     to handle isolation of the inductor cache dir.
@@ -80,46 +82,21 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase):


 class ReplicateTest(MultiProcessInductorTestCase):
-    # TODO: consider using all devices? The min(2, ...) here would limit the
-    # test to always run on 2 GPUs only.
     @property
     def world_size(self) -> int:
-        return min(2, torch.cuda.device_count())
+        return min(2, torch.get_device_module(device_type).device_count())

-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        super().tearDown()
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
     def _test_compile(
         self,
         *,
-        use_gpu: bool,
         no_sync: bool,
         setup_func: Optional[Callable] = None,
         no_inductor: bool = False,
         no_compile_forward: bool = False,
         checkpoint: bool = False,
+        device: Union[str, torch.device],
     ):
-        backend = "nccl" if use_gpu else "gloo"
-        dist.init_process_group(
-            backend=backend,
-            rank=self.rank,
-            world_size=self.world_size,
-            store=dist.FileStore(self.file_name, self.world_size),
-        )
-        if use_gpu:
-            torch.cuda.set_device(f"cuda:{self.rank}")
-            device = torch.device("cuda")
-        else:
-            device = torch.device("cpu")
-
+        self.create_pg(device)
         torch._dynamo.config.optimize_ddp = (
             "python_reducer_without_compiled_forward"
             if no_compile_forward
@@ -202,6 +179,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         self.assertEqual(
             tuple(model.parameters()), tuple(compiled_ddp_model.parameters())
         )
+        dist.destroy_process_group()

     def test_compile_cpu(self):
         # Test the coalesced_op with CPU.
@@ -209,7 +187,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=False)
+        self._test_compile(no_sync=False, device="cpu")

     def test_compile_cpu_no_sync(self):
         # Test the coalesced_op with CPU.
@@ -217,7 +195,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
             "fuse_ddp_with_coalesced_op",
             "schedule_comm_wait",
         ]
-        self._test_compile(use_gpu=False, no_sync=True)
+        self._test_compile(no_sync=True, device="cpu")

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -226,7 +204,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=False)
+        self._test_compile(no_sync=False, checkpoint=False, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -235,15 +213,14 @@ class ReplicateTest(MultiProcessInductorTestCase):
         reorder_for_locality=False, reorder_for_peak_memory=False
     )
     def test_compile_gpu_ac(self):
-        self._test_compile(use_gpu=True, no_sync=False, checkpoint=True)
+        self._test_compile(no_sync=False, checkpoint=True, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_bf16(self):
         # Check device capability wrt bf16
-        device = torch.device("cuda", self.rank % torch.cuda.device_count())
-        if not sm_is_or_higher_than(device, 8, 0):
+        if not sm_is_or_higher_than(torch.device(device_type), 8, 0):
             self.skipTest("bf16 requires sm >= 8.0")

         def setup(model, compiled_replicate_model, compiled_ddp_model) -> None:
@@ -254,7 +231,7 @@ class ReplicateTest(MultiProcessInductorTestCase):
                 None, ddp_default_hooks.bf16_compress_hook
             )

-        self._test_compile(use_gpu=True, no_sync=False, setup_func=setup)
+        self._test_compile(no_sync=False, setup_func=setup, device=device_type)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
@@ -270,14 +247,14 @@ class ReplicateTest(MultiProcessInductorTestCase):

         # TODO: figure out why we need to disable Inductor to avoid test errors.
         self._test_compile(
-            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+            no_sync=False, setup_func=setup, no_inductor=True, device=device_type
         )

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(2)
     def test_compile_backward_only(self):
-        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
+        self._test_compile(no_sync=False, no_compile_forward=True, device=device_type)

     def _test_bucketing(self, init_process_group=True, loop=1):
         if init_process_group:
@@ -397,7 +374,7 @@ class DDP_TP_Test(InductorTestCase):
         # Hmm, why a specific set_device call for rank 0?
         self.rank = 0
         self.world_size = 4
-        torch.cuda.set_device("cuda:0")
+        torch.get_device_module(device_type).set_device(device_type)

         store = FakeStore()
         dist.init_process_group(
@@ -419,7 +396,7 @@ class DDP_TP_Test(InductorTestCase):
         ref_model = Net()
         compiled_replicate_model = deepcopy(ref_model)
         mesh_2d = init_device_mesh(
-            "cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
+            device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
         )
         tp_mesh = mesh_2d["tp"]
         dp_mesh = mesh_2d["dp"]