[ROCm] Enable several distributed UTs (#164390)

Increase the tolerance for the following UTs as there was a slight mismatch seen on MI200.
    - test_data_parallel.py:test_strided_grad_layout
    - test_c10d_nccl.py:test_grad_layout_1devicemodule_1replicaperprocess

Skip for MI200:
    - test_fully_shard_training.py:test_2d_mlp_with_nd_mesh
    - test_2d_composability.py:test_train_parity_2d_mlp
    - test_fully_shard_overlap.py:test_fully_shard_training_overlap

Fixes #159489
Fixes #159488
Fixes #152700
Fixes #125555
Fixes #134139

Working as is on both MI200 and MI300:
Fixes #125991
Fixes #125918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164390
Approved by: https://github.com/jeffdaily
This commit is contained in:
Prachi
2025-10-03 19:52:51 +00:00
committed by PyTorch MergeBot
parent 1bb68271b7
commit 3ca09d65f1
7 changed files with 32 additions and 10 deletions

View File

@ -444,11 +444,11 @@ def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
"""Skips a test for given ROCm archs - multiprocess UTs"""
def decorator(func):
prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
arch_match = prop in arch
reason = None
if TEST_WITH_ROCM and arch_match:
reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
if TEST_WITH_ROCM:
prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
if prop in arch:
reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
return unittest.skipIf(reason is not None, reason)(func)