mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] Enable several distributed UTs (#164390)
Increase the tolerance for the following UTs as there was a slight mismatch seen on MI200.
- test_data_parallel.py:test_strided_grad_layout
- test_c10d_nccl.py:test_grad_layout_1devicemodule_1replicaperprocess

Skip for MI200:
- test_fully_shard_training.py:test_2d_mlp_with_nd_mesh
- test_2d_composability.py:test_train_parity_2d_mlp
- test_fully_shard_overlap.py:test_fully_shard_training_overlap

Fixes #159489
Fixes #159488
Fixes #152700
Fixes #125555
Fixes #134139

Working as is on both MI200 and MI300:
Fixes #125991
Fixes #125918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164390
Approved by: https://github.com/jeffdaily
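For reference, a minimal sketch of the skip pattern this change applies, assuming a ROCm build of PyTorch with at least two GPUs. The test class and test name below are hypothetical; only the decorators, constants, and imports come from the diff that follows.

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
    skip_if_rocm_arch_multiprocess,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import MI200_ARCH, run_tests


class ExampleShardingTest(FSDPTest):  # hypothetical test class for illustration
    @property
    def world_size(self) -> int:
        # Simplified: the real tests query the active device module.
        return min(2, torch.cuda.device_count())

    # Stacked guards: skip on MI200 (gfx90a) under ROCm, and skip anywhere
    # with fewer than two GPUs; MI300 and CUDA runs are unaffected.
    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
    @skip_if_lt_x_gpu(2)
    def test_example_overlap(self):  # hypothetical test body
        self.assertTrue(dist.is_initialized())


if __name__ == "__main__":
    run_tests()

The per-arch guard is evaluated at decoration time, so on non-MI200 machines the test is collected and run normally.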
@@ -10,14 +10,22 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.experimental import implicit_replication
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import (
     FSDPTest,
     get_devtype,
     patch_all_gather,
     patch_reduce_scatter,
 )
-from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU
+from torch.testing._internal.common_utils import (
+    get_cycles_per_ms,
+    MI200_ARCH,
+    run_tests,
+    TEST_HPU,
+)


 device_type = torch.device(get_devtype())
@@ -43,6 +51,7 @@ class TestFullyShardOverlap(FSDPTest):
     def world_size(self) -> int:
         return min(2, torch.get_device_module(device_type).device_count())

+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(2)
     @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
     def test_fully_shard_training_overlap(self):
@@ -27,7 +27,10 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     compiled_fsdp_test,
@@ -40,6 +43,7 @@ from torch.testing._internal.common_fsdp import (
 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    MI200_ARCH,
     run_tests,
     TEST_HPU,
     TEST_XPU,
@@ -1198,6 +1202,7 @@ class TestFullyShardNDTraining(FSDPTest):
             mesh_dim_names=("pp", "dp", "tp"),
         )

+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(4)
     def test_2d_mlp_with_nd_mesh(self):
         global_mesh = self.init_global_mesh()
@@ -41,10 +41,14 @@ from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
 from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import FSDPTest, MLP, MLPStack
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    MI200_ARCH,
     parametrize,
     run_tests,
     TEST_XPU,
@@ -121,6 +125,7 @@ class TestFullyShard2DTraining(FSDPTest):
             mesh_dim_names=("dp", "tp"),
         )

+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(2)
     def test_train_parity_2d_mlp(self):
         global_mesh = self.init_global_mesh()
@@ -2072,7 +2072,7 @@ class DistributedDataParallelTest(
                     opt = torch.optim.SGD(m.parameters(), lr=0.1)
                     opt_ddp = torch.optim.SGD(m_ddp.parameters(), lr=0.1)
                     has_half = any(p.dtype is torch.half for p in m.parameters())
-                    tol = 1.0e-3 if has_half else 1.0e-5
+                    tol = 3.0e-3 if has_half else 1.0e-5
                 except BaseException:
                     # Prints case-specific debugging info to narrow down failing case.
                     print(
@@ -760,7 +760,7 @@ class TestDataParallel(TestCase):
                opt = torch.optim.SGD(m.parameters(), lr=0.1)
                opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
                has_half = any(p.dtype is torch.half for p in m.parameters())
-               tol = 1.0e-3 if has_half else 1.0e-5
+               tol = 3.0e-3 if has_half else 1.0e-5
            except BaseException:
                # Prints case-specific debugging info to narrow down failing case.
                print(
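The two tolerance bumps above relax only the half-precision path. A hypothetical standalone helper (not the tests' actual assertion code) illustrating how such a dtype-dependent tolerance is applied when comparing gradients:

import torch


def grads_close(p: torch.nn.Parameter, p_ref: torch.nn.Parameter) -> bool:
    # Looser 3e-3 bound for half-precision parameters, 1e-5 otherwise,
    # mirroring the `tol` computation in the hunks above.
    tol = 3.0e-3 if p.dtype is torch.half else 1.0e-5
    assert p.grad is not None and p_ref.grad is not None
    return torch.allclose(p.grad, p_ref.grad, rtol=tol, atol=tol)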
@@ -444,11 +444,11 @@ def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
     """Skips a test for given ROCm archs - multiprocess UTs"""

     def decorator(func):
-        prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
-        arch_match = prop in arch
         reason = None
-        if TEST_WITH_ROCM and arch_match:
-            reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
+        if TEST_WITH_ROCM:
+            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            if prop in arch:
+                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"

         return unittest.skipIf(reason is not None, reason)(func)

@@ -104,6 +104,9 @@ except ImportError:
 SEED = 1234
 MI300_ARCH = ("gfx942",)
 MI200_ARCH = ("gfx90a")
+NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")

 class ProfilingMode(Enum):
     LEGACY = 1