[ROCm] Enable several distributed UTs (#164390)

Increase the tolerance for the following UTs, as a slight numerical mismatch was seen on MI200 (a sketch of the affected comparison follows the list):
    - test_data_parallel.py:test_strided_grad_layout
    - test_c10d_nccl.py:test_grad_layout_1devicemodule_1replicaperprocess
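
The bump only affects the half-precision branch of these checks. As a rough
standalone sketch of the comparison pattern (illustrative names and shapes,
not the actual test code):

    import torch

    # Gradients computed along two paths (e.g. single-device reference vs.
    # DDP replica) can differ slightly in fp16 due to reduction order.
    grad_ref = torch.randn(64, dtype=torch.half)
    grad_ddp = grad_ref + 2e-3  # within the new MI200-friendly bound

    has_half = grad_ref.dtype is torch.half
    tol = 3.0e-3 if has_half else 1.0e-5  # previously 1.0e-3 for fp16
    torch.testing.assert_close(grad_ddp, grad_ref, atol=tol, rtol=0)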

Skip for MI200 (usage sketch after the list):
    - test_fully_shard_training.py:test_2d_mlp_with_nd_mesh
    - test_2d_composability.py:test_train_parity_2d_mlp
    - test_fully_shard_overlap.py:test_fully_shard_training_overlap
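
These skips use the skip_if_rocm_arch_multiprocess decorator from
torch.testing._internal.common_distributed together with the new MI200_ARCH
tuple (both shown in the diffs below). A minimal sketch of the pattern on a
hypothetical test:

    from torch.testing._internal.common_distributed import (
        skip_if_lt_x_gpu,
        skip_if_rocm_arch_multiprocess,
    )
    from torch.testing._internal.common_fsdp import FSDPTest
    from torch.testing._internal.common_utils import MI200_ARCH, run_tests

    class MyFSDPTest(FSDPTest):  # hypothetical test class
        @skip_if_rocm_arch_multiprocess(MI200_ARCH)  # skip only on gfx90a
        @skip_if_lt_x_gpu(2)
        def test_example(self):
            ...

    if __name__ == "__main__":
        run_tests()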

Fixes #159489
Fixes #159488
Fixes #152700
Fixes #125555
Fixes #134139

Working as-is on both MI200 and MI300:
Fixes #125991
Fixes #125918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164390
Approved by: https://github.com/jeffdaily
Author: Prachi
Date: 2025-10-03 19:52:51 +00:00
Committed by: PyTorch MergeBot
Parent: 1bb68271b7
Commit: 3ca09d65f1
7 changed files with 32 additions and 10 deletions

File: test_fully_shard_overlap.py

@@ -10,14 +10,22 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.experimental import implicit_replication
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import (
     FSDPTest,
     get_devtype,
     patch_all_gather,
     patch_reduce_scatter,
 )
-from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU
+from torch.testing._internal.common_utils import (
+    get_cycles_per_ms,
+    MI200_ARCH,
+    run_tests,
+    TEST_HPU,
+)
 
 device_type = torch.device(get_devtype())
@@ -43,6 +51,7 @@ class TestFullyShardOverlap(FSDPTest):
     def world_size(self) -> int:
         return min(2, torch.get_device_module(device_type).device_count())
 
+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(2)
     @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
     def test_fully_shard_training_overlap(self):

File: test_fully_shard_training.py

@@ -27,7 +27,10 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
     compiled_fsdp_test,
@@ -40,6 +43,7 @@ from torch.testing._internal.common_fsdp import (
 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    MI200_ARCH,
     run_tests,
     TEST_HPU,
     TEST_XPU,
@@ -1198,6 +1202,7 @@ class TestFullyShardNDTraining(FSDPTest):
             mesh_dim_names=("pp", "dp", "tp"),
         )
 
+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(4)
     def test_2d_mlp_with_nd_mesh(self):
         global_mesh = self.init_global_mesh()

File: test_2d_composability.py

@@ -41,10 +41,14 @@ from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
 from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_distributed import (
+    skip_if_lt_x_gpu,
+    skip_if_rocm_arch_multiprocess,
+)
 from torch.testing._internal.common_fsdp import FSDPTest, MLP, MLPStack
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    MI200_ARCH,
     parametrize,
     run_tests,
     TEST_XPU,
@@ -121,6 +125,7 @@ class TestFullyShard2DTraining(FSDPTest):
             mesh_dim_names=("dp", "tp"),
         )
 
+    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
     @skip_if_lt_x_gpu(2)
     def test_train_parity_2d_mlp(self):
         global_mesh = self.init_global_mesh()

File: test_c10d_nccl.py

@@ -2072,7 +2072,7 @@ class DistributedDataParallelTest(
                 opt = torch.optim.SGD(m.parameters(), lr=0.1)
                 opt_ddp = torch.optim.SGD(m_ddp.parameters(), lr=0.1)
                 has_half = any(p.dtype is torch.half for p in m.parameters())
-                tol = 1.0e-3 if has_half else 1.0e-5
+                tol = 3.0e-3 if has_half else 1.0e-5
             except BaseException:
                 # Prints case-specific debugging info to narrow down failing case.
                 print(

File: test_data_parallel.py

@@ -760,7 +760,7 @@ class TestDataParallel(TestCase):
                 opt = torch.optim.SGD(m.parameters(), lr=0.1)
                 opt_dp = torch.optim.SGD(m_dp.parameters(), lr=0.1)
                 has_half = any(p.dtype is torch.half for p in m.parameters())
-                tol = 1.0e-3 if has_half else 1.0e-5
+                tol = 3.0e-3 if has_half else 1.0e-5
             except BaseException:
                 # Prints case-specific debugging info to narrow down failing case.
                 print(

File: torch/testing/_internal/common_distributed.py

@@ -444,11 +444,11 @@ def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
     """Skips a test for given ROCm archs - multiprocess UTs"""
 
     def decorator(func):
-        prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
-        arch_match = prop in arch
         reason = None
-        if TEST_WITH_ROCM and arch_match:
-            reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
+        if TEST_WITH_ROCM:
+            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            if prop in arch:
+                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
         return unittest.skipIf(reason is not None, reason)(func)
 
     return decorator
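
The refactor above matters because the skip reason is computed at decoration
time (i.e. at module import), not when the test runs: the old code called
torch.cuda.get_device_properties(0) unconditionally, which can fail on hosts
without a visible GPU. A tiny standalone illustration of that timing (all
names hypothetical):

    import unittest

    def skip_on(predicate, reason):
        def decorator(func):
            # predicate() runs here, while the class body is being
            # evaluated, not inside the test runner.
            return unittest.skipIf(predicate(), reason)(func)
        return decorator

    class Demo(unittest.TestCase):
        @skip_on(lambda: True, "decided at import time")
        def test_skipped(self):
            self.fail("never reached")

    if __name__ == "__main__":
        unittest.main()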

File: torch/testing/_internal/common_utils.py

@@ -104,6 +104,9 @@ except ImportError:
 SEED = 1234
 MI300_ARCH = ("gfx942",)
+MI200_ARCH = ("gfx90a",)
 NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")
 
 class ProfilingMode(Enum):
     LEGACY = 1
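
The arch tuples above are matched against the base architecture name: on ROCm,
torch.cuda.get_device_properties(0).gcnArchName may carry feature flags after
a colon (e.g. "gfx90a:sramecc+:xnack-"), which is why the decorator splits on
":" first. A quick illustration with a hard-coded name:

    # Hypothetical gcnArchName string for an MI200-class GPU.
    name = "gfx90a:sramecc+:xnack-"
    base = name.split(":")[0]
    print(base)                            # gfx90a
    print(base in ("gfx90a",))             # True: tuple membership
    print(base in ("gfx1100", "gfx1101"))  # False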