Make TestPythonRegistration clean up after itself (#102292)

We did this for TestCustomOp, now we are applying the same thing to
TestPythonRegistration.

This PR:
- changes TestPythonRegistration to register new ops under a single
namespace (self.test_ns)
- clean up the namespace by deleting it from torch.ops after each test
is done running.

This avoids a problem where if an op is re-defined, torch.ops.myns.op
crashes because we do some caching. The workaround in many of these
tests have been to just create an op with a different name, but this PR
makes it so that we don't need to do this.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102292
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
Richard Zou
2023-06-01 11:06:45 -07:00
committed by PyTorch MergeBot
parent 72cdbf6a3f
commit eaeea62ee4

View File

@ -39,6 +39,12 @@ class TestDispatcherPythonBindings(TestCase):
class TestPythonRegistration(TestCase):
test_ns = '_test_python_registration'
def tearDown(self):
if hasattr(torch.ops, self.test_ns):
del torch.ops._test_python_registration
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
my_lib1 = Library("aten", "IMPL")
@ -95,7 +101,7 @@ class TestPythonRegistration(TestCase):
def test_finalizer(self):
impls_refcnt = sys.getrefcount(torch.library._impls)
lib = Library("_torch_testing", "FRAGMENT")
lib = Library(self.test_ns, "FRAGMENT")
lib.define("foo123(Tensor x) -> Tensor")
# 1 for `lib`, 1 for sys.getrefcount
@ -110,8 +116,8 @@ class TestPythonRegistration(TestCase):
def foo123(x):
pass
lib.impl("_torch_testing::foo123", foo123, "CPU")
key = '_torch_testing/foo123/CPU'
lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
key = f'{self.test_ns}/foo123/CPU'
self.assertTrue(key in torch.library._impls)
saved_op_impls = lib._op_impls
@ -287,7 +293,7 @@ class TestPythonRegistration(TestCase):
del my_lib1
def test_create_new_library(self) -> None:
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")
my_lib1.define("sum(Tensor self) -> Tensor")
@ -297,12 +303,13 @@ class TestPythonRegistration(TestCase):
return args[0].clone()
x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum(x), x)
op = getattr(torch.ops, self.test_ns).sum
self.assertEqual(op(x), x)
my_lib2 = Library("foo", "IMPL")
my_lib2 = Library(self.test_ns, "IMPL")
# Example 2
@torch.library.impl(my_lib2, torch.ops.foo.sum.default, "ZeroTensor")
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
def my_sum_zt(*args, **kwargs):
if args[0]._is_zerotensor():
return torch._efficientzerotensor(args[0].shape)
@ -310,14 +317,14 @@ class TestPythonRegistration(TestCase):
return args[0].clone()
y = torch._efficientzerotensor(3)
self.assertTrue(torch.ops.foo.sum(y)._is_zerotensor())
self.assertEqual(torch.ops.foo.sum(x), x)
self.assertTrue(op(y)._is_zerotensor())
self.assertEqual(op(x), x)
del my_lib2
del my_lib1
def test_create_new_library_fragment_no_existing(self):
my_lib = Library("foo", "FRAGMENT")
my_lib = Library(self.test_ns, "FRAGMENT")
my_lib.define("sum2(Tensor self) -> Tensor")
@ -326,15 +333,15 @@ class TestPythonRegistration(TestCase):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum2(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
del my_lib
def test_create_new_library_fragment_with_existing(self):
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")
# Create a fragment
my_lib2 = Library("foo", "FRAGMENT")
my_lib2 = Library(self.test_ns, "FRAGMENT")
my_lib2.define("sum4(Tensor self) -> Tensor")
@ -343,10 +350,10 @@ class TestPythonRegistration(TestCase):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum4(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
# Create another fragment
my_lib3 = Library("foo", "FRAGMENT")
my_lib3 = Library(self.test_ns, "FRAGMENT")
my_lib3.define("sum3(Tensor self) -> Tensor")
@ -355,7 +362,7 @@ class TestPythonRegistration(TestCase):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum3(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
del my_lib1
del my_lib2
@ -364,7 +371,7 @@ class TestPythonRegistration(TestCase):
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
def test_alias_analysis(self):
def test_helper(alias_analysis=""):
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")
called = [0]
@ -374,9 +381,9 @@ class TestPythonRegistration(TestCase):
@torch.jit.script
def _test():
torch.ops.foo._op()
torch.ops._test_python_registration._op()
assert "foo::_op" in str(_test.graph)
assert "_test_python_registration::_op" in str(_test.graph)
with self.assertRaises(AssertionError):
test_helper("") # alias_analysis="FROM_SCHEMA"
@ -399,14 +406,14 @@ class TestPythonRegistration(TestCase):
s0, s1 = ft.shape
tlib = Library("tlib", "DEF")
tlib = Library(self.test_ns, "DEF")
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
@impl(tlib, "sqsum", "CompositeExplicitAutograd")
def sqsum(a: SymInt, b: SymInt):
return a * a + b * b
out = torch.ops.tlib.sqsum.default(s0, s1)
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
out_val = shape_env.evaluate_expr(out.node.expr)
self.assertEquals(out_val, 13)