[ROCm] Enabled JIT UTs on ROCm (#164582)

This PR is to enable the following tests rocm.

test/test_jit.py::TestBackends::test_save_load
test/test_jit.py::TestBackends::test_execution
test/test_jit.py::TestBackends::test_errors
test/test_jit.py::TestCUDA::test_current_stream

Verified that the tests pass on AMD gfx90a and gfx942 arch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164582
Approved by: https://github.com/jeffdaily
This commit is contained in:
rraminen
2025-10-03 20:16:38 +00:00
committed by PyTorch MergeBot
parent 8ec8c14ace
commit da49a57d34
2 changed files with 7 additions and 16 deletions

View File

@ -16,8 +16,6 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
raise_on_run_directly,
skipIfRocm,
TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase
@ -61,7 +59,7 @@ class BasicModule(torch.nn.Module):
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
@ -144,7 +142,6 @@ class BasicModuleTest(JitBackendTestCase):
self.check_function("sub_accum", (input, input))
self.check_function("forward", (input, input))
@skipIfRocm
def test_save_load(self):
# Lowered module should produce the same outputs.
self.test_execution()
@ -203,7 +200,6 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
backend_method = self.lowered_module.__getattr__("forward")
backend_method(*(input, input))
@skipIfRocm
def test_save_load(self):
# Test that saving the lowered module is OK but loading fails because the backend is not available.
buffer = io.BytesIO()
@ -447,7 +443,7 @@ class SelectiveLoweringTest(JitBackendTestCase):
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
@ -465,27 +461,23 @@ class TestBackends(JitTestCase):
def setUp(self):
super().setUp()
if not TEST_WITH_ROCM:
self.basic_module_test.setUp()
self.basic_module_unavailable_test.setUp()
self.nested_module_test.setUp()
self.selective_lowering_test.setUp()
self.basic_module_test.setUp()
self.basic_module_unavailable_test.setUp()
self.nested_module_test.setUp()
self.selective_lowering_test.setUp()
@skipIfRocm
def test_execution(self):
self.basic_module_test.test_execution()
self.basic_module_unavailable_test.test_execution()
self.nested_module_test.test_execution()
self.selective_lowering_test.test_execution()
@skipIfRocm
def test_save_load(self):
self.basic_module_test.test_save_load()
self.basic_module_unavailable_test.test_save_load()
self.nested_module_test.test_save_load()
self.selective_lowering_test.test_save_load()
@skipIfRocm
def test_errors(self):
self.selective_lowering_test.test_errors()
@ -510,7 +502,7 @@ class BasicModuleAdd(torch.nn.Module):
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase):

View File

@ -121,7 +121,6 @@ class TestCUDA(JitTestCase):
self.assertTrue(event_default_args)
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_current_stream(self):
# Test current stream on the device and check if the stream device index