mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Revert "add amp support for custom backend (#96188)"
This reverts commit cf12edee02a44009c4f06e36efa97d9a7372ab35. Reverted https://github.com/pytorch/pytorch/pull/96188 on behalf of https://github.com/kit1980 due to Broke some linalg tests : https://github.com/pytorch/pytorch/actions/runs/4420037607/jobs/7750708339
This commit is contained in:
@ -784,26 +784,6 @@ class DummyXPUModule:
|
||||
def is_available():
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def is_autocast_foo_enabled():
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_autocast_foo_dtype():
|
||||
return torch.float16
|
||||
|
||||
@staticmethod
|
||||
def set_autocast_foo_enabled(enable):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def set_autocast_foo_dtype(dtype):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_amp_supported_dtype():
|
||||
return [torch.float16]
|
||||
|
||||
|
||||
class TestExtensionUtils(TestCase):
|
||||
def test_external_module_register(self):
|
||||
@ -826,26 +806,6 @@ class TestExtensionUtils(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
|
||||
torch._register_device_module('xpu', DummyXPUModule)
|
||||
|
||||
def test_external_module_and_backend_register(self):
|
||||
torch.utils.rename_privateuse1_backend('foo')
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been set"):
|
||||
torch.utils.rename_privateuse1_backend('dummmy')
|
||||
|
||||
custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
self.assertEqual(custom_backend_name, 'foo')
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
torch.foo.is_available()
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
|
||||
with torch.autocast(device_type=custom_backend_name):
|
||||
pass
|
||||
torch._register_device_module('foo', DummyXPUModule)
|
||||
|
||||
torch.foo.is_available()
|
||||
with torch.autocast(device_type=custom_backend_name):
|
||||
pass
|
||||
|
||||
|
||||
class TestDeviceUtils(TestCase):
|
||||
def test_basic(self):
|
||||
|
Reference in New Issue
Block a user