Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] Enable several distributed UTs (#164390)
Increase the tolerance for the following UTs, as a slight numerical mismatch was seen on MI200:
- test_data_parallel.py:test_strided_grad_layout
- test_c10d_nccl.py:test_grad_layout_1devicemodule_1replicaperprocess

Skip for MI200:
- test_fully_shard_training.py:test_2d_mlp_with_nd_mesh
- test_2d_composability.py:test_train_parity_2d_mlp
- test_fully_shard_overlap.py:test_fully_shard_training_overlap

Fixes #159489
Fixes #159488
Fixes #152700
Fixes #125555
Fixes #134139

Working as is on both MI200 and MI300:
Fixes #125991
Fixes #125918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164390
Approved by: https://github.com/jeffdaily
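The "increase the tolerance" half of the change typically amounts to widening atol/rtol on a ROCm branch. A minimal sketch, assuming TEST_WITH_ROCM from common_utils; the tolerance values here are illustrative, not the ones in the PR:

import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM

# Default tolerances; widened under ROCm because MI200 showed a
# slight numerical mismatch (values are hypothetical).
atol, rtol = 1e-5, 1e-5
if TEST_WITH_ROCM:
    atol, rtol = 1e-4, 1e-4

got = torch.randn(8, 8)
expected = got + 1e-6  # stand-in for a real reference result
torch.testing.assert_close(got, expected, atol=atol, rtol=rtol)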
@@ -444,11 +444,11 @@ def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
     """Skips a test for given ROCm archs - multiprocess UTs"""

     def decorator(func):
-        prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
-        arch_match = prop in arch
         reason = None
-        if TEST_WITH_ROCM and arch_match:
-            reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
+        if TEST_WITH_ROCM:
+            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
+            if prop in arch:
+                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"

         return unittest.skipIf(reason is not None, reason)(func)
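For readers without the full file, here is a self-contained sketch of the decorator after this change. The trailing return decorator is assumed from the function's shape, and the TEST_WITH_ROCM import mirrors what the surrounding module already uses. The key point of the refactor: torch.cuda.get_device_properties is now queried only under ROCm, so the decorator no longer touches the CUDA runtime on builds where that call could fail.

import unittest

import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM


def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
    """Skips a test for given ROCm archs - multiprocess UTs"""

    def decorator(func):
        reason = None
        if TEST_WITH_ROCM:
            # gcnArchName looks like "gfx90a:sramecc+:xnack-";
            # keep only the leading arch token for matching.
            prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
            if prop in arch:
                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
        return unittest.skipIf(reason is not None, reason)(func)

    return decorator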
@@ -104,6 +104,9 @@ except ImportError:
 SEED = 1234
 MI300_ARCH = ("gfx942",)
+MI200_ARCH = ("gfx90a",)  # trailing comma keeps this a one-element tuple, not a str
 NAVI_ARCH = ("gfx1030", "gfx1100", "gfx1101", "gfx1200", "gfx1201")
+NAVI3_ARCH = ("gfx1100", "gfx1101")
+NAVI4_ARCH = ("gfx1200", "gfx1201")

 class ProfilingMode(Enum):
     LEGACY = 1
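A hypothetical usage pairing the new MI200_ARCH tuple with the decorator above. The test class and method body are illustrative; the import paths assume the decorator lives in common_distributed and the arch tuples in common_utils, as the two hunks suggest.

import unittest

from torch.testing._internal.common_distributed import skip_if_rocm_arch_multiprocess
from torch.testing._internal.common_utils import MI200_ARCH, run_tests


class ExampleTest(unittest.TestCase):
    @skip_if_rocm_arch_multiprocess(MI200_ARCH)
    def test_train_parity_2d_mlp(self):
        # Stand-in body; under ROCm on a gfx90a (MI200) device this test
        # is skipped with the reason string built by the decorator.
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()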