[ROCm][CI] Add support for gfx1100 in rocm workflow + test skips (#148355)
This PR adds infrastructure support for gfx1100 in the rocm workflow. Nodes have been allocated for this effort. @dnikolaev-amd contributed all the test skips.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148355
Approved by: https://github.com/jeffdaily

Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
@@ -1975,15 +1975,20 @@ def skipIfRocm(func=None, *, msg="test doesn't currently work on the ROCm stack"
         return dec_fn(func)
     return dec_fn
 
+def getRocmArchName(device_index: int = 0):
+    return torch.cuda.get_device_properties(device_index).gcnArchName
+
+def isRocmArchAnyOf(arch: tuple[str, ...]):
+    rocmArch = getRocmArchName()
+    return any(x in rocmArch for x in arch)
+
 def skipIfRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] in arch:
-                    reason = f"skipIfRocm: test skipped on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test skipped on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
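For context, a hedged sketch of how the rewritten decorator and the new helpers are meant to be used in a test file (the test class, the NAVI_ARCH tuple contents, and the assertions are illustrative assumptions, not part of this diff):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmArch

# Illustrative arch tuple; entries are matched as substrings of gcnArchName.
NAVI_ARCH = ("gfx1030", "gfx1100")

class TestNaviSkips(TestCase):
    @skipIfRocmArch(NAVI_ARCH)
    def test_not_supported_on_navi(self):
        # On a ROCm Navi node (e.g. gfx1100) the decorator raises
        # unittest.SkipTest before this body runs; elsewhere it executes.
        x = torch.ones(4)
        self.assertEqual(x.sum().item(), 4.0)

if __name__ == "__main__":
    run_tests()

Note the subtle semantic change: the old inline check compared the exact base name (prop.gcnArchName.split(":")[0]) against the tuple, while isRocmArchAnyOf does substring matching against the full gcnArchName string, so an entry like "gfx1100" still matches when feature flags such as ":sramecc+:xnack-" are appended.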
@@ -2001,11 +2006,9 @@ def runOnRocmArch(arch: tuple[str, ...]):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
-            if TEST_WITH_ROCM:
-                prop = torch.cuda.get_device_properties(0)
-                if prop.gcnArchName.split(":")[0] not in arch:
-                    reason = f"skipIfRocm: test only runs on {arch}"
-                    raise unittest.SkipTest(reason)
+            if TEST_WITH_ROCM and not isRocmArchAnyOf(arch):
+                reason = f"skipIfRocm: test only runs on {arch}"
+                raise unittest.SkipTest(reason)
             return fn(self, *args, **kwargs)
         return wrap_fn
     return dec_fn
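A matching sketch for the allow-list counterpart, runOnRocmArch (the MI300_ARCH tuple below is an assumption for illustration):

from torch.testing._internal.common_utils import TestCase, runOnRocmArch

MI300_ARCH = ("gfx942",)  # illustrative MI300-series gcnArchName prefix

class TestMI300Only(TestCase):
    @runOnRocmArch(MI300_ARCH)
    def test_mi300_only(self):
        # Skipped on any ROCm device whose gcnArchName contains none of the
        # listed strings; when TEST_WITH_ROCM is False (e.g. on CUDA builds)
        # the test always runs.
        ...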
@@ -2055,15 +2058,18 @@ def skipIfHpu(fn):
             fn(*args, **kwargs)
     return wrapper
 
+def getRocmVersion() -> tuple[int, int]:
+    from torch.testing._internal.common_cuda import _get_torch_rocm_version
+    rocm_version = _get_torch_rocm_version()
+    return (rocm_version[0], rocm_version[1])
+
 # Skips a test on CUDA if ROCm is available and its version is lower than requested.
 def skipIfRocmVersionLessThan(version=None):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
             if TEST_WITH_ROCM:
-                rocm_version = str(torch.version.hip)
-                rocm_version = rocm_version.split("-", maxsplit=1)[0]  # ignore git sha
-                rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
+                rocm_version_tuple = getRocmVersion()
                 if rocm_version_tuple is None or version is None or rocm_version_tuple < tuple(version):
                     reason = f"ROCm {rocm_version_tuple} is available but {version} required"
                     raise unittest.SkipTest(reason)
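Finally, a hedged sketch of the version-gated decorator after this refactor (the version numbers are illustrative):

from torch.testing._internal.common_utils import TestCase, skipIfRocmVersionLessThan

class TestVersionGated(TestCase):
    @skipIfRocmVersionLessThan((6, 3))
    def test_needs_rocm_6_3(self):
        # On ROCm older than 6.3 this raises unittest.SkipTest with a message
        # like "ROCm (6, 2) is available but (6, 3) required"; on newer ROCm
        # (and on non-ROCm builds) the test body runs.
        ...

Since getRocmVersion() is annotated to return a concrete (major, minor) tuple, the surviving rocm_version_tuple is None guard looks vestigial, but the diff leaves that existing check untouched.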