[ROCm][SymmMem] re-enable UTs (#162811)

After the UT suite moved to `MultiProcContinuousTest`, `skipIfRocm` decorator started failing rather than skipping UTs because now we spawn multiple threads before the skip decorator is taken into account and the skip decorator was raising an exception to exit the process. But, the parent process treated the child process exiting as a crash rather than a skip. Additionally, in `MultiProcContinuousTest`, if one UT fails all subsequent ones are also skipped which makes sense since there's one setup for the entire suite. However, this showed up as many failing/skipped UTs in the parity.

I added multiprocess version of skip decorators for ROCm, including, `skip_if_rocm_arch_multiprocess` and
`skip_if_rocm_ver_lessthan_multiprocess`. These are needed as symmetric memory feature is only supported on MI300 onwards and we need to skip them for other archs and some UTs only work after ROCm7.0.

Fixes #161249
Fixes #161187
Fixes #161078
Fixes #160989
Fixes #160881
Fixes #160768
Fixes #160716
Fixes #160665
Fixes #160621
Fixes #160549
Fixes #160506
Fixes #160445
Fixes #160347
Fixes #160203
Fixes #160177
Fixes #160049
Fixes #159921
Fixes #159764
Fixes #159643
Fixes #159499
Fixes #159397
Fixes #159396
Fixes #159347
Fixes #159067
Fixes #159066
Fixes #158916
Fixes #158760
Fixes #158759
Fixes #158422
Fixes #158138
Fixes #158136
Fixes #158135

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162811
Approved by: https://github.com/jeffdaily
This commit is contained in:
Prachi Gupta
2025-09-16 15:35:35 +00:00
committed by PyTorch MergeBot
parent 3ee071aa85
commit f638854e1d
3 changed files with 123 additions and 38 deletions

View File

@ -36,6 +36,7 @@ from torch.testing._internal.common_utils import (
FILE_SCHEMA,
find_free_port,
IS_SANDCASTLE,
LazyVal,
retry_on_connect_failures,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
@ -421,17 +422,62 @@ def requires_multicast_support():
)
def evaluate_platform_supports_symm_mem():
if TEST_WITH_ROCM:
arch_list = ["gfx942", "gfx950"]
for arch in arch_list:
if arch in torch.cuda.get_device_properties(0).gcnArchName:
return True
if TEST_CUDA:
return True
return False
PLATFORM_SUPPORTS_SYMM_MEM: bool = LazyVal(
lambda: evaluate_platform_supports_symm_mem()
)
def skip_if_rocm_multiprocess(func):
"""Skips a test for ROCm"""
func.skip_if_rocm_multiprocess = True
"""Skips a test for ROCm multiprocess UTs"""
return unittest.skipIf(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func)
@wraps(func)
def wrapper(*args, **kwargs):
if not TEST_WITH_ROCM:
return func(*args, **kwargs)
sys.exit(TEST_SKIPS["skipIfRocm"].exit_code)
return wrapper
def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
"""Skips a test for given ROCm archs - multiprocess UTs"""
def decorator(func):
prop = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
arch_match = prop in arch
reason = None
if TEST_WITH_ROCM and arch_match:
reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"
return unittest.skipIf(reason is not None, reason)(func)
return decorator
def skip_if_rocm_ver_lessthan_multiprocess(version=None):
"""Skips a test for ROCm based on ROCm ver - multiprocess UTs"""
def decorator(func):
reason = None
if TEST_WITH_ROCM:
rocm_version = str(torch.version.hip)
rocm_version = rocm_version.split("-", maxsplit=1)[0] # ignore git sha
rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
if (
rocm_version_tuple is None
or version is None
or rocm_version_tuple < tuple(version)
):
reason = f"skip_if_rocm_ver_lessthan_multiprocess: ROCm {rocm_version_tuple} is available but {version} required"
return unittest.skipIf(reason is not None, reason)(func)
return decorator
def skip_if_win32():