diff --git a/test/test_autograd_fallback.py b/test/test_autograd_fallback.py
index d32bf870841b..d6252ac6f34a 100644
--- a/test/test_autograd_fallback.py
+++ b/test/test_autograd_fallback.py
@@ -6,7 +6,7 @@ import warnings
 import numpy as np
 
 import torch
-from torch.library import _scoped_library, Library
+from torch.library import _scoped_library
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -28,20 +28,24 @@ def autograd_fallback_mode(mode):
 class TestAutogradFallback(TestCase):
     test_ns = "_test_autograd_fallback"
 
+    def setUp(self):
+        super().setUp()
+        self.libraries = []
+
     def tearDown(self):
         if hasattr(torch.ops, self.test_ns):
             delattr(torch.ops, self.test_ns)
-        if hasattr(self, "lib"):
-            del self.lib.m
-            del self.lib
+        for lib in self.libraries:
+            lib._destroy()
+        del self.libraries
 
     def get_op(self, name):
         return getattr(getattr(torch.ops, self.test_ns), name).default
 
     def get_lib(self):
-        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
-        self.lib = lib
-        return lib
+        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
+        self.libraries.append(result)
+        return result
 
     @parametrize("mode", ("nothing", "warn"))
     def test_no_grad(self, mode):