Fix _fix_weakref memory leak (#90823)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90823
Approved by: https://github.com/eellison, https://github.com/albanD
Author: Edward Z. Yang
Date: 2022-12-15 06:08:05 +08:00
Committed by: PyTorch MergeBot
Parent: d19791e4cd
Commit: 283cf718ed

2 changed files with 20 additions and 3 deletions

test/test_torch.py

@@ -8329,7 +8329,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del fin_tensor
         self.assertTrue(m[0])
 
-    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
     def test_tensor_weakref_dealloc(self):
         x = torch.empty(2)
@@ -8445,7 +8445,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertTrue(m1[0])
         self.assertTrue(m2[0])
 
-    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
     def test_dead_weak_ref(self):
         x = torch.empty(2)
         w_x = weakref.ref(x)
@@ -8475,6 +8475,23 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         del y
         x.sigmoid()
 
+    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
+    def test_fix_weakref_no_leak(self):
+        import weakref
+
+        called = False
+
+        a = torch.randn(1)
+
+        def callback(w):
+            nonlocal called
+            called = True
+        wa = weakref.ref(a, callback)
+        a._fix_weakref()
+        del a
+
+        self.assertTrue(called)
+
     # FIXME: move to test_linalg
     @torch.inference_mode()
     def test_bmm_multithreaded(self):
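
To exercise just the new regression test from a source checkout (assuming the standard PyTorch test layout, where these tests live in test/test_torch.py):

    python test/test_torch.py -k test_fix_weakref_no_leak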

torch/csrc/autograd/python_variable.cpp

@@ -680,7 +680,7 @@ PyObject* THPVariable_pynew(
 static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
   const auto& var = THPVariable_Unpack(self);
-  THPVariable_Wrap(var);
+  Py_DECREF(THPVariable_Wrap(var));
   Py_RETURN_NONE;
 }
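
The leak: THPVariable_Wrap(var) returns a new strong reference to the tensor's PyObject, and the old code discarded that pointer without releasing it, so every _fix_weakref() call pinned the tensor's refcount one higher and its weakref callbacks could never fire. The fix releases the temporary reference immediately. A minimal sketch of how to observe the balanced refcount from Python, assuming a build that includes this fix:

    import sys

    import torch

    a = torch.randn(1)
    before = sys.getrefcount(a)

    # With the fix, the temporary reference created by THPVariable_Wrap
    # inside _fix_weakref() is released before it returns, so the
    # refcount is unchanged; on a leaky build it grows by one per call.
    a._fix_weakref()
    after = sys.getrefcount(a)

    assert after == before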