Add way to actually delete a torch.library.Library object (#118318)

Relying on object lifetimes in Python is a bad idea because of reference
cycles. Previously, when a torch.library.Library object was destroyed, it
cleared all of the registrations associated with it, but reference cycles
made it unclear when that destruction would actually happen.
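
As a minimal sketch of why `__del__`-based cleanup is nondeterministic
(plain Python, not PyTorch internals; the `Registration` class here is
hypothetical):

```python
import gc

class Registration:
    # Stand-in for an object whose destructor undoes a registration.
    def __del__(self):
        print("registration cleared")

reg = Registration()
holder = {"reg": reg}
reg.holder = holder   # reg -> holder -> reg: a reference cycle

del reg, holder       # nothing is printed: the cycle keeps both objects alive
gc.collect()          # "registration cleared" appears only once the cyclic
                      # garbage collector runs, whenever that happens to be
```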

This PR:
- adds torch::Library::clear(), which deterministically releases all of
  the RAII registration handles held by the torch::Library object
- adds a new `torch.library._scoped_library` context manager, which creates
  a library and cleans it up at the end of the scope using the above.
  All tests (unless they already handle library lifetimes) should use
  this new API; see the usage sketch below.
- rewrites some flaky tests to use `_scoped_library`.
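
For illustration, a minimal usage sketch of the context manager (the `mylib`
namespace and the `add_one` op are invented for this example):

```python
import torch
from torch.library import _scoped_library

with _scoped_library("mylib", "DEF") as lib:
    lib.define("add_one(Tensor x) -> Tensor")
    lib.impl("add_one", lambda x: x + 1, "CPU")
    # The op is registered only while the scope is alive.
    assert torch.ops.mylib.add_one(torch.zeros(3)).sum() == 3

# On exit the registrations are released deterministically, even if
# reference cycles keep the Library object itself alive longer.
```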

In the future we'll probably migrate all of our torch.library tests to
use `_scoped_library`, but that's a larger change, since those tests
span several thousand lines of code.

I'm hoping this will deflake those tests; we'll see.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118318
Approved by: https://github.com/albanD
Author: rzou
Date: 2024-01-26 11:08:49 -08:00
Committed by: PyTorch MergeBot
Parent: f129e3fe03
Commit: b256b7b348
6 changed files with 161 additions and 131 deletions


@@ -3,7 +3,7 @@
 import tempfile
 import torch
 from copy import deepcopy
-from torch.library import Library, impl, fallthrough_kernel
+from torch.library import Library, impl, fallthrough_kernel, _scoped_library
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch import SymInt
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -48,48 +48,43 @@ class TestPythonRegistration(TestCase):
     def test_override_aten_ops_with_multiple_libraries(self) -> None:
         x = torch.tensor([1, 2])
-        my_lib1 = Library("aten", "IMPL")
-        my_lib2 = Library("aten", "IMPL")
-
-        # Example 1
-        def my_neg(*args, **kwargs):
-            return args[0]._neg_view()
+        with _scoped_library("aten", "IMPL") as my_lib2:
+            with _scoped_library("aten", "IMPL") as my_lib1:
+                # Example 1
+                def my_neg(*args, **kwargs):
+                    return args[0]._neg_view()
 
-        # Now we are secretly making the operator a view op so autograd needs to know how
-        # to handle it
-        my_lib1.impl('neg', my_neg, "AutogradCPU")
+                # Now we are secretly making the operator a view op so autograd needs to know how
+                # to handle it
+                my_lib1.impl('neg', my_neg, "AutogradCPU")
 
-        self.assertTrue(torch.neg(x).is_neg())
+                self.assertTrue(torch.neg(x).is_neg())
 
-        # RuntimeError: impl("aten::neg", ...):
-        # Explicitly provided namespace (aten) in operator name does not match ...
-        with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
-            my_lib3 = Library("foo", "DEF")
-            my_lib3.define("neg(Tensor self) -> Tensor")
-            my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
-            del my_lib3
+                # RuntimeError: impl("aten::neg", ...):
+                # Explicitly provided namespace (aten) in operator name does not match ...
+                with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
+                    my_lib3 = Library("foo", "DEF")
+                    my_lib3.define("neg(Tensor self) -> Tensor")
+                    my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
+                    del my_lib3
 
-        # Example 2
-        def my_mul(*args, **kwargs):
-            return torch.zeros_like(args[0])
+                # Example 2
+                def my_mul(*args, **kwargs):
+                    return torch.zeros_like(args[0])
 
-        # torch.ops.aten.mul.Tensor
-        my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
+                # torch.ops.aten.mul.Tensor
+                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
 
-        y = torch._efficientzerotensor(2)
-        self.assertFalse(torch.mul(x, y)._is_zerotensor())
+                y = torch._efficientzerotensor(2)
+                self.assertFalse(torch.mul(x, y)._is_zerotensor())
 
-        # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
-        # combination if someone overrided the behavior for the same before them
-        with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
-            my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
+                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
+                # combination if someone overrided the behavior for the same before them
+                with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
+                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
 
-        del my_lib1
-
-        # Validate that lib2 is not affected by removing lib1
-        self.assertFalse(torch.mul(x, y)._is_zerotensor())
-
-        del my_lib2
+            # Validate that lib2 is not affected by removing lib1
+            self.assertFalse(torch.mul(x, y)._is_zerotensor())
 
         # Validate that the old behavior is restored for neg and mul
         self.assertFalse(torch.neg(x).is_neg())
 
@@ -419,33 +414,28 @@
         self.assertEqual(out_val, 13)
 
     def test_register_functional_op_error_cases(self):
-        lib = Library(self.test_ns, "FRAGMENT")
-        with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs_)
-        with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
-        with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
-            register_functional_op(lib, "abs", torch.ops.aten.abs.out)
+        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
+            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs_)
+            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
+            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
+                register_functional_op(lib, "abs", torch.ops.aten.abs.out)
 
-        schemas = [
-            'foo(Tensor x, Tensor(a!)[] y) -> ()',
-            'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
-            'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
-        ]
-        del lib
+        schemas = [
+            'foo(Tensor x, Tensor(a!)[] y) -> ()',
+            'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
+            'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
+        ]
 
         for schema in schemas:
-            lib = Library(self.test_ns, "FRAGMENT")
-            try:
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define(schema)
                 with self.assertRaisesRegex(RuntimeError, "NYI"):
                     register_functional_op(
                         lib,
                         "foo_functional",
                         getattr(torch.ops, self.test_ns).foo.default)
-            finally:
-                del lib
-                delattr(torch.ops, self.test_ns)
 
     def _check_is_functional_variant(self, mutable_op, functional_op, args):
         # functional op should not mutate
@ -483,98 +473,97 @@ class TestPythonRegistration(TestCase):
self.assertTrue(has_functional_op)
def test_register_functional_op_no_returns(self):
lib = Library(self.test_ns, 'FRAGMENT')
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
def test_register_functional_op_with_optional(self):
lib = Library(self.test_ns, 'FRAGMENT')
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
def foo_impl(x, y, z, w):
y.fill_(3.14)
z.fill_(2.71)
if w is not None:
w.fill_(1.618)
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
def foo_impl(x, y, z, w):
y.fill_(3.14)
z.fill_(2.71)
if w is not None:
w.fill_(1.618)
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
def test_register_functional_op_one_return(self):
lib = Library(self.test_ns, 'FRAGMENT')
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
z.fill_(0.99)
return x.clone()
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
z.fill_(0.99)
return x.clone()
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
"foo_functional",
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
"foo_functional",
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
def test_register_functional_op_multiple_returns(self):
lib = Library(self.test_ns, 'FRAGMENT')
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
return x.clone(), z.clone()
def foo_impl(x, y, z, w):
y.fill_(3.14)
w.fill_(2.71)
return x.clone(), z.clone()
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
lib.impl('foo', foo_impl, 'CPU')
register_functional_op(
lib,
'foo_functional',
getattr(torch.ops, self.test_ns).foo.default)
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
x = torch.randn([])
y = torch.randn([])
z = torch.randn([])
w = torch.randn([])
self._check_is_functional_variant(
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
def test_register_fallthrough(self):
try: