mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix TestAutogradFallback flaky tests under Dynamo: migrate to lib._destroy() (#159443)
under dynamo, the libraries couldn't properly be cleared unless we manually did `gc.collect()`, but that's slow. it also worked if we just used the _destroy() method to tear down FIXES #159398 #159349 #159254 #159237 #159153 #159114 #159040 #158910 #158841 #158763 #158735 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159443 Approved by: https://github.com/zou3519, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
7821fbc560
commit
644fee2610
@ -6,7 +6,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.library import _scoped_library, Library
|
||||
from torch.library import _scoped_library
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -28,20 +28,24 @@ def autograd_fallback_mode(mode):
|
||||
class TestAutogradFallback(TestCase):
|
||||
test_ns = "_test_autograd_fallback"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.libraries = []
|
||||
|
||||
def tearDown(self):
|
||||
if hasattr(torch.ops, self.test_ns):
|
||||
delattr(torch.ops, self.test_ns)
|
||||
if hasattr(self, "lib"):
|
||||
del self.lib.m
|
||||
del self.lib
|
||||
for lib in self.libraries:
|
||||
lib._destroy()
|
||||
del self.libraries
|
||||
|
||||
def get_op(self, name):
|
||||
return getattr(getattr(torch.ops, self.test_ns), name).default
|
||||
|
||||
def get_lib(self):
|
||||
lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901
|
||||
self.lib = lib
|
||||
return lib
|
||||
result = torch.library.Library(self.test_ns, "FRAGMENT") # noqa: TOR901
|
||||
self.libraries.append(result)
|
||||
return result
|
||||
|
||||
@parametrize("mode", ("nothing", "warn"))
|
||||
def test_no_grad(self, mode):
|
||||
|
Reference in New Issue
Block a user