mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm][SymmMem] re-enable UTs (#162811)
After the UT suite moved to `MultiProcContinuousTest`, the `skipIfRocm` decorator started failing UTs rather than skipping them: the worker processes are now spawned before the skip decorator is taken into account, and the decorator raised an exception to exit the process. The parent process then treated the exiting child process as a crash rather than a skip. Additionally, in `MultiProcContinuousTest`, if one UT fails, all subsequent ones are skipped as well, which makes sense since there is one setup for the entire suite; however, this showed up as many failing/skipped UTs in the parity. I added multiprocess versions of the skip decorators for ROCm, including `skip_if_rocm_arch_multiprocess` and `skip_if_rocm_ver_lessthan_multiprocess`. These are needed because the symmetric memory feature is only supported on MI300 onwards, so the UTs must be skipped on other archs, and some UTs only work from ROCm 7.0 onwards. Fixes #161249 Fixes #161187 Fixes #161078 Fixes #160989 Fixes #160881 Fixes #160768 Fixes #160716 Fixes #160665 Fixes #160621 Fixes #160549 Fixes #160506 Fixes #160445 Fixes #160347 Fixes #160203 Fixes #160177 Fixes #160049 Fixes #159921 Fixes #159764 Fixes #159643 Fixes #159499 Fixes #159397 Fixes #159396 Fixes #159347 Fixes #159067 Fixes #159066 Fixes #158916 Fixes #158760 Fixes #158759 Fixes #158422 Fixes #158138 Fixes #158136 Fixes #158135 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/162811 Approved by: https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
3ee071aa85
commit
f638854e1d
@ -36,6 +36,7 @@ from torch.testing._internal.common_utils import (
|
||||
FILE_SCHEMA,
|
||||
find_free_port,
|
||||
IS_SANDCASTLE,
|
||||
LazyVal,
|
||||
retry_on_connect_failures,
|
||||
skip_but_pass_in_sandcastle,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
@ -421,17 +422,62 @@ def requires_multicast_support():
|
||||
)
|
||||
|
||||
|
||||
def evaluate_platform_supports_symm_mem():
    """Return whether the current platform is expected to support symmetric memory.

    On ROCm test runs, only devices whose ``gcnArchName`` contains one of the
    supported architecture names qualify. On all other runs, the result
    follows ``TEST_CUDA``.

    Returns:
        bool: ``True`` if symmetric-memory tests should be able to run here.
    """
    if TEST_WITH_ROCM:
        supported_archs = ("gfx942", "gfx950")
        # Nothing to query when no device is visible.
        if not torch.cuda.is_available():
            return False
        # Look the device up once rather than once per candidate arch.
        gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
        # Decide the ROCm case here instead of falling through to the
        # TEST_CUDA check below: TEST_CUDA is presumably also truthy on a ROCm
        # machine with a GPU (verify against its definition), which would
        # report unsupported archs as supported.
        return any(arch in gcn_arch_name for arch in supported_archs)

    return bool(TEST_CUDA)
|
||||
|
||||
|
||||
# Wrapped in LazyVal with a lambda so the platform probe is looked up and run
# through the wrapper rather than executed directly at this statement
# (presumably on first use -- confirm against LazyVal's definition).
PLATFORM_SUPPORTS_SYMM_MEM: bool = LazyVal(lambda: evaluate_platform_supports_symm_mem())
|
||||
|
||||
|
||||
def skip_if_rocm_multiprocess(func):
    """Skips a test for ROCm multiprocess UTs.

    The skip is expressed through ``unittest.skipIf`` so that it is decided
    when the test is decorated, instead of exiting the process from inside
    the test body.

    Args:
        func: the test function to decorate.

    Returns:
        The result of applying ``unittest.skipIf`` to ``func``, using the
        ``skipIfRocm`` message from ``TEST_SKIPS`` as the skip reason.
    """
    return unittest.skipIf(TEST_WITH_ROCM, TEST_SKIPS["skipIfRocm"].message)(func)
|
||||
def skip_if_rocm_arch_multiprocess(arch: tuple[str, ...]):
    """Skips a test for given ROCm archs - multiprocess UTs.

    Args:
        arch: architecture names on which the test must be skipped. Each is
            compared against device 0's ``gcnArchName`` with everything from
            the first ``":"`` onwards removed.

    Returns:
        A decorator that applies ``unittest.skipIf`` to the test, with the
        skip condition true only on a ROCm run whose device arch is listed
        in ``arch``.
    """

    def decorator(func):
        reason = None
        # Query the device only on ROCm runs: the arch name is not used
        # otherwise, and the lookup may raise at decoration (import) time on
        # hosts without a visible GPU.
        if TEST_WITH_ROCM:
            device_arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
            if device_arch in arch:
                reason = f"skip_if_rocm_arch_multiprocess: test skipped on {arch}"

        return unittest.skipIf(reason is not None, reason)(func)

    return decorator
|
||||
|
||||
|
||||
def skip_if_rocm_ver_lessthan_multiprocess(version=None):
    """Skips a test for ROCm based on ROCm ver - multiprocess UTs.

    Args:
        version: minimum required ROCm version as an iterable of ints
            (it is passed to ``tuple`` and compared against the parsed
            ``torch.version.hip``). With the default ``None`` the test is
            always skipped on ROCm runs.

    Returns:
        A decorator that applies ``unittest.skipIf`` to the test, with the
        skip condition true only on a ROCm run whose version is unknown or
        lower than ``version``.
    """

    def decorator(func):
        reason = None
        if TEST_WITH_ROCM:
            rocm_version = str(torch.version.hip)
            rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
            try:
                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
            except ValueError:
                # A non-numeric component (for instance the text "None" when
                # no HIP version is reported) cannot be compared; treat the
                # version as unknown so the check below skips the test
                # instead of raising while the test module is being imported.
                rocm_version_tuple = None
            if (
                rocm_version_tuple is None
                or version is None
                or rocm_version_tuple < tuple(version)
            ):
                reason = f"skip_if_rocm_ver_lessthan_multiprocess: ROCm {rocm_version_tuple} is available but {version} required"

        return unittest.skipIf(reason is not None, reason)(func)

    return decorator
|
||||
|
||||
|
||||
def skip_if_win32():
|
||||
|
Reference in New Issue
Block a user