mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make TestPythonRegistration clean up after itself (#102292)
We did this for TestCustomOp, now we are applying the same thing to TestPythonRegistration. This PR: - changes TestPythonRegistration to register new ops under a single namespace (self.test_ns) - clean up the namespace by deleting it from torch.ops after each test is done running. This avoids a problem where if an op is re-defined, torch.ops.myns.op crashes because we do some caching. The workaround in many of these tests have been to just create an op with a different name, but this PR makes it so that we don't need to do this. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/102292 Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
72cdbf6a3f
commit
eaeea62ee4
@ -39,6 +39,12 @@ class TestDispatcherPythonBindings(TestCase):
|
||||
|
||||
|
||||
class TestPythonRegistration(TestCase):
|
||||
test_ns = '_test_python_registration'
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(torch.ops, self.test_ns):
|
||||
del torch.ops._test_python_registration
|
||||
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
my_lib1 = Library("aten", "IMPL")
|
||||
@ -95,7 +101,7 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
def test_finalizer(self):
|
||||
impls_refcnt = sys.getrefcount(torch.library._impls)
|
||||
lib = Library("_torch_testing", "FRAGMENT")
|
||||
lib = Library(self.test_ns, "FRAGMENT")
|
||||
lib.define("foo123(Tensor x) -> Tensor")
|
||||
|
||||
# 1 for `lib`, 1 for sys.getrefcount
|
||||
@ -110,8 +116,8 @@ class TestPythonRegistration(TestCase):
|
||||
def foo123(x):
|
||||
pass
|
||||
|
||||
lib.impl("_torch_testing::foo123", foo123, "CPU")
|
||||
key = '_torch_testing/foo123/CPU'
|
||||
lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
|
||||
key = f'{self.test_ns}/foo123/CPU'
|
||||
self.assertTrue(key in torch.library._impls)
|
||||
|
||||
saved_op_impls = lib._op_impls
|
||||
@ -287,7 +293,7 @@ class TestPythonRegistration(TestCase):
|
||||
del my_lib1
|
||||
|
||||
def test_create_new_library(self) -> None:
|
||||
my_lib1 = Library("foo", "DEF")
|
||||
my_lib1 = Library(self.test_ns, "DEF")
|
||||
|
||||
my_lib1.define("sum(Tensor self) -> Tensor")
|
||||
|
||||
@ -297,12 +303,13 @@ class TestPythonRegistration(TestCase):
|
||||
return args[0].clone()
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.ops.foo.sum(x), x)
|
||||
op = getattr(torch.ops, self.test_ns).sum
|
||||
self.assertEqual(op(x), x)
|
||||
|
||||
my_lib2 = Library("foo", "IMPL")
|
||||
my_lib2 = Library(self.test_ns, "IMPL")
|
||||
|
||||
# Example 2
|
||||
@torch.library.impl(my_lib2, torch.ops.foo.sum.default, "ZeroTensor")
|
||||
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
|
||||
def my_sum_zt(*args, **kwargs):
|
||||
if args[0]._is_zerotensor():
|
||||
return torch._efficientzerotensor(args[0].shape)
|
||||
@ -310,14 +317,14 @@ class TestPythonRegistration(TestCase):
|
||||
return args[0].clone()
|
||||
|
||||
y = torch._efficientzerotensor(3)
|
||||
self.assertTrue(torch.ops.foo.sum(y)._is_zerotensor())
|
||||
self.assertEqual(torch.ops.foo.sum(x), x)
|
||||
self.assertTrue(op(y)._is_zerotensor())
|
||||
self.assertEqual(op(x), x)
|
||||
|
||||
del my_lib2
|
||||
del my_lib1
|
||||
|
||||
def test_create_new_library_fragment_no_existing(self):
|
||||
my_lib = Library("foo", "FRAGMENT")
|
||||
my_lib = Library(self.test_ns, "FRAGMENT")
|
||||
|
||||
my_lib.define("sum2(Tensor self) -> Tensor")
|
||||
|
||||
@ -326,15 +333,15 @@ class TestPythonRegistration(TestCase):
|
||||
return args[0]
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.ops.foo.sum2(x), x)
|
||||
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
|
||||
|
||||
del my_lib
|
||||
|
||||
def test_create_new_library_fragment_with_existing(self):
|
||||
my_lib1 = Library("foo", "DEF")
|
||||
my_lib1 = Library(self.test_ns, "DEF")
|
||||
|
||||
# Create a fragment
|
||||
my_lib2 = Library("foo", "FRAGMENT")
|
||||
my_lib2 = Library(self.test_ns, "FRAGMENT")
|
||||
|
||||
my_lib2.define("sum4(Tensor self) -> Tensor")
|
||||
|
||||
@ -343,10 +350,10 @@ class TestPythonRegistration(TestCase):
|
||||
return args[0]
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.ops.foo.sum4(x), x)
|
||||
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
|
||||
|
||||
# Create another fragment
|
||||
my_lib3 = Library("foo", "FRAGMENT")
|
||||
my_lib3 = Library(self.test_ns, "FRAGMENT")
|
||||
|
||||
my_lib3.define("sum3(Tensor self) -> Tensor")
|
||||
|
||||
@ -355,7 +362,7 @@ class TestPythonRegistration(TestCase):
|
||||
return args[0]
|
||||
|
||||
x = torch.tensor([1, 2])
|
||||
self.assertEqual(torch.ops.foo.sum3(x), x)
|
||||
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
|
||||
|
||||
del my_lib1
|
||||
del my_lib2
|
||||
@ -364,7 +371,7 @@ class TestPythonRegistration(TestCase):
|
||||
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
|
||||
def test_alias_analysis(self):
|
||||
def test_helper(alias_analysis=""):
|
||||
my_lib1 = Library("foo", "DEF")
|
||||
my_lib1 = Library(self.test_ns, "DEF")
|
||||
|
||||
called = [0]
|
||||
|
||||
@ -374,9 +381,9 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
@torch.jit.script
|
||||
def _test():
|
||||
torch.ops.foo._op()
|
||||
torch.ops._test_python_registration._op()
|
||||
|
||||
assert "foo::_op" in str(_test.graph)
|
||||
assert "_test_python_registration::_op" in str(_test.graph)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
test_helper("") # alias_analysis="FROM_SCHEMA"
|
||||
@ -399,14 +406,14 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
s0, s1 = ft.shape
|
||||
|
||||
tlib = Library("tlib", "DEF")
|
||||
tlib = Library(self.test_ns, "DEF")
|
||||
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
|
||||
|
||||
@impl(tlib, "sqsum", "CompositeExplicitAutograd")
|
||||
def sqsum(a: SymInt, b: SymInt):
|
||||
return a * a + b * b
|
||||
|
||||
out = torch.ops.tlib.sqsum.default(s0, s1)
|
||||
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
|
||||
out_val = shape_env.evaluate_expr(out.node.expr)
|
||||
self.assertEquals(out_val, 13)
|
||||
|
||||
|
Reference in New Issue
Block a user