mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add way to actually delete a torch.library.Library object (#118318)
Relying on object lifetimes in Python is a bad idea due to reference cycles. Previously, when a torch.library.Library object gets destroyed, it clears all the registrations associated with it, but it's unclear when it actually gets destroyed due to the existence of refcycles. This PR: - adds torch::Library::clear(), which deterministically releases all of the RAII registration handles of the torch::Library object - adds a new `torch.library._scoped_library` context manager, which creates a library and cleans it up at the end of the scope using the previous item. All tests (unless they already handle library lifetimes) should use this new API - Rewrites some flaky tests to use `_scoped_library`. In the future we'll probably migrate all of our torch.library tests to use `_scoped_library`, but that's kind of annoying because we have multiple thousands of LOC I'm hoping this will deflake those tests; we'll see. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118318 Approved by: https://github.com/albanD
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
import tempfile
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
from torch.library import Library, impl, fallthrough_kernel
|
||||
from torch.library import Library, impl, fallthrough_kernel, _scoped_library
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch import SymInt
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
@ -48,48 +48,43 @@ class TestPythonRegistration(TestCase):
|
||||
|
||||
def test_override_aten_ops_with_multiple_libraries(self) -> None:
|
||||
x = torch.tensor([1, 2])
|
||||
my_lib1 = Library("aten", "IMPL")
|
||||
my_lib2 = Library("aten", "IMPL")
|
||||
with _scoped_library("aten", "IMPL") as my_lib2:
|
||||
with _scoped_library("aten", "IMPL") as my_lib1:
|
||||
# Example 1
|
||||
def my_neg(*args, **kwargs):
|
||||
return args[0]._neg_view()
|
||||
|
||||
# Example 1
|
||||
def my_neg(*args, **kwargs):
|
||||
return args[0]._neg_view()
|
||||
# Now we are secretly making the operator a view op so autograd needs to know how
|
||||
# to handle it
|
||||
my_lib1.impl('neg', my_neg, "AutogradCPU")
|
||||
|
||||
# Now we are secretly making the operator a view op so autograd needs to know how
|
||||
# to handle it
|
||||
my_lib1.impl('neg', my_neg, "AutogradCPU")
|
||||
self.assertTrue(torch.neg(x).is_neg())
|
||||
|
||||
self.assertTrue(torch.neg(x).is_neg())
|
||||
# RuntimeError: impl("aten::neg", ...):
|
||||
# Explicitly provided namespace (aten) in operator name does not match ...
|
||||
with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
|
||||
my_lib3 = Library("foo", "DEF")
|
||||
my_lib3.define("neg(Tensor self) -> Tensor")
|
||||
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
|
||||
del my_lib3
|
||||
|
||||
# RuntimeError: impl("aten::neg", ...):
|
||||
# Explicitly provided namespace (aten) in operator name does not match ...
|
||||
with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
|
||||
my_lib3 = Library("foo", "DEF")
|
||||
my_lib3.define("neg(Tensor self) -> Tensor")
|
||||
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
|
||||
del my_lib3
|
||||
# Example 2
|
||||
def my_mul(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
|
||||
# Example 2
|
||||
def my_mul(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
# torch.ops.aten.mul.Tensor
|
||||
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
|
||||
|
||||
# torch.ops.aten.mul.Tensor
|
||||
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
|
||||
y = torch._efficientzerotensor(2)
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
y = torch._efficientzerotensor(2)
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
|
||||
# combination if someone overrided the behavior for the same before them
|
||||
with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
|
||||
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
|
||||
|
||||
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
|
||||
# combination if someone overrided the behavior for the same before them
|
||||
with self.assertRaisesRegex(RuntimeError, 'already a kernel registered from python'):
|
||||
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
|
||||
|
||||
del my_lib1
|
||||
|
||||
# Validate that lib2 is not affected by removing lib1
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
del my_lib2
|
||||
# Validate that lib2 is not affected by removing lib1
|
||||
self.assertFalse(torch.mul(x, y)._is_zerotensor())
|
||||
|
||||
# Validate that the old behavior is restored for neg and mul
|
||||
self.assertFalse(torch.neg(x).is_neg())
|
||||
@ -419,33 +414,28 @@ class TestPythonRegistration(TestCase):
|
||||
self.assertEqual(out_val, 13)
|
||||
|
||||
def test_register_functional_op_error_cases(self):
|
||||
lib = Library(self.test_ns, "FRAGMENT")
|
||||
with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs_)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs.out)
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs_)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
|
||||
register_functional_op(lib, "abs", torch.ops.aten.abs.out)
|
||||
|
||||
schemas = [
|
||||
'foo(Tensor x, Tensor(a!)[] y) -> ()',
|
||||
'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
|
||||
'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
|
||||
]
|
||||
del lib
|
||||
schemas = [
|
||||
'foo(Tensor x, Tensor(a!)[] y) -> ()',
|
||||
'foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)',
|
||||
'foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))',
|
||||
]
|
||||
|
||||
for schema in schemas:
|
||||
lib = Library(self.test_ns, "FRAGMENT")
|
||||
try:
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
lib.define(schema)
|
||||
with self.assertRaisesRegex(RuntimeError, "NYI"):
|
||||
register_functional_op(
|
||||
lib,
|
||||
"foo_functional",
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
finally:
|
||||
del lib
|
||||
delattr(torch.ops, self.test_ns)
|
||||
|
||||
def _check_is_functional_variant(self, mutable_op, functional_op, args):
|
||||
# functional op should not mutate
|
||||
@ -483,98 +473,97 @@ class TestPythonRegistration(TestCase):
|
||||
self.assertTrue(has_functional_op)
|
||||
|
||||
def test_register_functional_op_no_returns(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()')
|
||||
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
|
||||
def test_register_functional_op_with_optional(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()')
|
||||
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
z.fill_(2.71)
|
||||
if w is not None:
|
||||
w.fill_(1.618)
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
z.fill_(2.71)
|
||||
if w is not None:
|
||||
w.fill_(1.618)
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, None))
|
||||
|
||||
def test_register_functional_op_one_return(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor')
|
||||
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
z.fill_(0.99)
|
||||
return x.clone()
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
z.fill_(0.99)
|
||||
return x.clone()
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
"foo_functional",
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
"foo_functional",
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
|
||||
def test_register_functional_op_multiple_returns(self):
|
||||
lib = Library(self.test_ns, 'FRAGMENT')
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
|
||||
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
|
||||
lib.define('foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)')
|
||||
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
return x.clone(), z.clone()
|
||||
def foo_impl(x, y, z, w):
|
||||
y.fill_(3.14)
|
||||
w.fill_(2.71)
|
||||
return x.clone(), z.clone()
|
||||
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
lib.impl('foo', foo_impl, 'CPU')
|
||||
register_functional_op(
|
||||
lib,
|
||||
'foo_functional',
|
||||
getattr(torch.ops, self.test_ns).foo.default)
|
||||
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
z = torch.randn([])
|
||||
w = torch.randn([])
|
||||
self._check_is_functional_variant(
|
||||
getattr(torch.ops, self.test_ns).foo.default,
|
||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||
|
||||
def test_register_fallthrough(self):
|
||||
try:
|
||||
|
Reference in New Issue
Block a user