diff --git a/test/run_test.py b/test/run_test.py
index f91f0cc323cf..6f8126c5fef3 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -531,6 +531,7 @@ XPU_TEST = [
 # The tests inside these files should never be run in parallel with each other
 RUN_PARALLEL_BLOCKLIST = [
+    "test_extension_utils",
     "test_cpp_extensions_jit",
     "test_cpp_extensions_open_device_registration",
     "test_cpp_extensions_stream_and_event",
diff --git a/test/test_extension_utils.py b/test/test_extension_utils.py
new file mode 100644
index 000000000000..d624f93a1d3f
--- /dev/null
+++ b/test/test_extension_utils.py
@@ -0,0 +1,88 @@
+# Owner(s): ["module: PrivateUse1"]
+import sys
+
+import torch
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class DummyPrivateUse1Module:
+    @staticmethod
+    def is_available():
+        return True
+
+    @staticmethod
+    def is_autocast_enabled():
+        return True
+
+    @staticmethod
+    def get_autocast_dtype():
+        return torch.float16
+
+    @staticmethod
+    def set_autocast_enabled(enable):
+        pass
+
+    @staticmethod
+    def set_autocast_dtype(dtype):
+        pass
+
+    @staticmethod
+    def get_amp_supported_dtype():
+        return [torch.float16]
+
+
+class TestExtensionUtils(TestCase):
+    def tearDown(self):
+        # Clean up
+        backend_name = torch._C._get_privateuse1_backend_name()
+        if hasattr(torch, backend_name):
+            delattr(torch, backend_name)
+        if f"torch.{backend_name}" in sys.modules:
+            del sys.modules[f"torch.{backend_name}"]
+
+    def test_external_module_register(self):
+        # Built-in module
+        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
+            torch._register_device_module("cuda", torch.cuda)
+
+        # Wrong device type
+        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
+            torch._register_device_module("dummmy", DummyPrivateUse1Module)
+
+        with self.assertRaises(AttributeError):
+            torch.privateuseone.is_available()  # type: ignore[attr-defined]
+
+        torch._register_device_module("privateuseone", DummyPrivateUse1Module)
+
+        torch.privateuseone.is_available()  # type: ignore[attr-defined]
+
+        # No supporting for override
+        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
+            torch._register_device_module("privateuseone", DummyPrivateUse1Module)
+
+    def test_external_module_register_with_renamed_backend(self):
+        torch.utils.rename_privateuse1_backend("foo")
+        with self.assertRaisesRegex(RuntimeError, "has already been set"):
+            torch.utils.rename_privateuse1_backend("dummmy")
+
+        custom_backend_name = torch._C._get_privateuse1_backend_name()
+        self.assertEqual(custom_backend_name, "foo")
+
+        with self.assertRaises(AttributeError):
+            torch.foo.is_available()  # type: ignore[attr-defined]
+
+        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
+            with torch.autocast(device_type=custom_backend_name):
+                pass
+        torch._register_device_module("foo", DummyPrivateUse1Module)
+
+        torch.foo.is_available()  # type: ignore[attr-defined]
+        with torch.autocast(device_type=custom_backend_name):
+            pass
+
+        self.assertEqual(torch._utils._get_device_index("foo:1"), 1)
+        self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_utils.py b/test/test_utils.py
index 6ddc600be21f..0fdd59edce58 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1002,85 +1002,6 @@ class TestStandaloneCPPJIT(TestCase):
         shutil.rmtree(build_dir)
 
 
-class DummyPrivateUse1Module:
-    @staticmethod
-    def is_available():
-        return True
-
-    @staticmethod
-    def is_autocast_enabled():
-        return True
-
-    @staticmethod
-    def get_autocast_dtype():
-        return torch.float16
-
-    @staticmethod
-    def set_autocast_enabled(enable):
-        pass
-
-    @staticmethod
-    def set_autocast_dtype(dtype):
-        pass
-
-    @staticmethod
-    def get_amp_supported_dtype():
-        return [torch.float16]
-
-
-class TestExtensionUtils(TestCase):
-    def tearDown(self):
-        # Clean up
-        backend_name = torch._C._get_privateuse1_backend_name()
-        if hasattr(torch, backend_name):
-            delattr(torch, backend_name)
-        if f"torch.{backend_name}" in sys.modules:
-            del sys.modules[f"torch.{backend_name}"]
-
-    def test_external_module_register(self):
-        # Built-in module
-        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
-            torch._register_device_module("cuda", torch.cuda)
-
-        # Wrong device type
-        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
-            torch._register_device_module("dummmy", DummyPrivateUse1Module)
-
-        with self.assertRaises(AttributeError):
-            torch.privateuseone.is_available()  # type: ignore[attr-defined]
-
-        torch._register_device_module("privateuseone", DummyPrivateUse1Module)
-
-        torch.privateuseone.is_available()  # type: ignore[attr-defined]
-
-        # No supporting for override
-        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
-            torch._register_device_module("privateuseone", DummyPrivateUse1Module)
-
-    def test_external_module_register_with_renamed_backend(self):
-        torch.utils.rename_privateuse1_backend("foo")
-        with self.assertRaisesRegex(RuntimeError, "has already been set"):
-            torch.utils.rename_privateuse1_backend("dummmy")
-
-        custom_backend_name = torch._C._get_privateuse1_backend_name()
-        self.assertEqual(custom_backend_name, "foo")
-
-        with self.assertRaises(AttributeError):
-            torch.foo.is_available()  # type: ignore[attr-defined]
-
-        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
-            with torch.autocast(device_type=custom_backend_name):
-                pass
-        torch._register_device_module("foo", DummyPrivateUse1Module)
-
-        torch.foo.is_available()  # type: ignore[attr-defined]
-        with torch.autocast(device_type=custom_backend_name):
-            pass
-
-        self.assertEqual(torch._utils._get_device_index("foo:1"), 1)
-        self.assertEqual(torch._utils._get_device_index(torch.device("foo:2")), 2)
-
-
 class TestRenderUtils(TestCase):
     def test_basic(self):
         self.assertExpectedInline(