mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] Enabled JIT UTs on ROCm (#164582)
This PR is to enable the following tests rocm. test/test_jit.py::TestBackends::test_save_load test/test_jit.py::TestBackends::test_execution test/test_jit.py::TestBackends::test_errors test/test_jit.py::TestCUDA::test_current_stream Verified that the tests pass on AMD gfx90a and gfx942 arch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164582 Approved by: https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
8ec8c14ace
commit
da49a57d34
@ -16,8 +16,6 @@ from torch.testing._internal.common_utils import (
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
raise_on_run_directly,
|
||||
skipIfRocm,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
@ -61,7 +59,7 @@ class BasicModule(torch.nn.Module):
|
||||
|
||||
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class JitBackendTestCase(JitTestCase):
|
||||
@ -144,7 +142,6 @@ class BasicModuleTest(JitBackendTestCase):
|
||||
self.check_function("sub_accum", (input, input))
|
||||
self.check_function("forward", (input, input))
|
||||
|
||||
@skipIfRocm
|
||||
def test_save_load(self):
|
||||
# Lowered module should produce the same outputs.
|
||||
self.test_execution()
|
||||
@ -203,7 +200,6 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
|
||||
backend_method = self.lowered_module.__getattr__("forward")
|
||||
backend_method(*(input, input))
|
||||
|
||||
@skipIfRocm
|
||||
def test_save_load(self):
|
||||
# Test that saving the lowered module is OK but loading fails because the backend is not available.
|
||||
buffer = io.BytesIO()
|
||||
@ -447,7 +443,7 @@ class SelectiveLoweringTest(JitBackendTestCase):
|
||||
|
||||
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class TestBackends(JitTestCase):
|
||||
@ -465,27 +461,23 @@ class TestBackends(JitTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not TEST_WITH_ROCM:
|
||||
self.basic_module_test.setUp()
|
||||
self.basic_module_unavailable_test.setUp()
|
||||
self.nested_module_test.setUp()
|
||||
self.selective_lowering_test.setUp()
|
||||
self.basic_module_test.setUp()
|
||||
self.basic_module_unavailable_test.setUp()
|
||||
self.nested_module_test.setUp()
|
||||
self.selective_lowering_test.setUp()
|
||||
|
||||
@skipIfRocm
|
||||
def test_execution(self):
|
||||
self.basic_module_test.test_execution()
|
||||
self.basic_module_unavailable_test.test_execution()
|
||||
self.nested_module_test.test_execution()
|
||||
self.selective_lowering_test.test_execution()
|
||||
|
||||
@skipIfRocm
|
||||
def test_save_load(self):
|
||||
self.basic_module_test.test_save_load()
|
||||
self.basic_module_unavailable_test.test_save_load()
|
||||
self.nested_module_test.test_save_load()
|
||||
self.selective_lowering_test.test_save_load()
|
||||
|
||||
@skipIfRocm
|
||||
def test_errors(self):
|
||||
self.selective_lowering_test.test_errors()
|
||||
|
||||
@ -510,7 +502,7 @@ class BasicModuleAdd(torch.nn.Module):
|
||||
|
||||
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class JitBackendTestCaseWithCompiler(JitTestCase):
|
||||
|
@ -121,7 +121,6 @@ class TestCUDA(JitTestCase):
|
||||
|
||||
self.assertTrue(event_default_args)
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
|
||||
def test_current_stream(self):
|
||||
# Test current stream on the device and check if the stream device index
|
||||
|
Reference in New Issue
Block a user