Fix resurrection logic to trigger early enough (#137267)

Fixes https://github.com/pytorch/pytorch/issues/136358

The bug here is that the Tensor object is actually implemented by two classes: `Tensor` from `_tensor.py` and `TensorBase` from C++.
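
You can see this two-level layering from Python. A minimal sketch (the exact name and repr of the C++ base type vary across PyTorch versions):

```python
import torch

t = torch.rand(2)
print(type(t))                # <class 'torch.Tensor'>, defined in _tensor.py
print(torch.Tensor.__base__)  # the C++-implemented TensorBase type
```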

Before this PR, they had the following GC methods:

Tensor:
- tp_clear: subtype_clear
- tp_traverse: THPVariable_subclass_traverse
- tp_dealloc: THPVariable_subclass_dealloc

TensorBase:
- tp_clear: THPVariable_clear
- tp_traverse: THPFunction_traverse (a stub that only raises an error)
- tp_dealloc: object_dealloc

The problem is that when `tp_clear` is called on the Tensor, `subtype_clear` clears everything owned by the `Tensor` type, in particular its `__dict__` attribute, before delegating to the TensorBase `tp_clear`, which is where we detect that resurrection needs to happen and skip the clearing. By the time that check runs, the `__dict__` has already been lost.
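
Condensed from the regression test added below, the failing scenario looked roughly like this (a sketch based on the issue, not an exact reproduction):

```python
import gc
import torch

t = torch.rand(2, requires_grad=True).clone()
t.foo = 2    # stash something in the Tensor's __dict__
l = [t]
l.append(l)  # make the Tensor part of a reference cycle
t2 = t ** 2  # the autograd graph keeps `t` alive from C++
del t, l
gc.collect() # must resurrect the Tensor without dropping its __dict__
assert hasattr(t2.grad_fn._saved_self, "foo")  # failed before this fix
```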
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137267
Approved by: https://github.com/ezyang, https://github.com/kshitij12345
Author: albanD
Date: 2024-10-04 18:23:03 +00:00
Committed by: PyTorch MergeBot
Parent: bd48933323
Commit: c0deec120f

2 changed files with 46 additions and 8 deletions


@@ -10036,6 +10036,30 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertEqual(MyStorage.finalized_count, 1)
         self.assertTrue(m[0])
 
+    def test_tensor_ressurecting_clear(self):
+        # Regression test for https://github.com/pytorch/pytorch/issues/136358
+        # A Tensor with a custom __dict__
+        # Autograd here is for the c++ reference later
+        t = torch.rand(2, requires_grad=True).clone()
+        t.foo = 2
+        # that is part of a cycle
+        l = []
+        l.append(l)
+        l.append(t)
+
+        # Keep the Tensor alive from c++
+        # Using the autograd graph here (any other means would work)
+        t2 = t ** 2
+        self.assertIs(t2.grad_fn._saved_self, t)
+
+        # Clear all python references and trigger the gc
+        del t, l
+        gc.collect()
+
+        # We used to lose the dict!
+        self.assertTrue(hasattr(t2.grad_fn._saved_self, "foo"))
+
     def test_tensor_slot_dealloc(self):
         class SlotTensor1(torch.Tensor):

torch/csrc/autograd/python_variable.cpp

@@ -409,7 +409,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
   return true;
 }
 
-static int THPVariable_clear(THPVariable* self) {
+static int THPVariable_subclass_clear(THPVariable* self) {
   // Is it OK for an object to still be live after running
   // tp_clear? Yes. When Python is breaking reference cycles, it can't assume
   // that an object will dealloc after it's cleared. The source code explicitly
@@ -465,7 +465,7 @@ static int THPVariable_clear(THPVariable* self) {
   // !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
   // ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
   // please report a bug to PyTorch. Exception raised from
-  // THPVariable_clear at
+  // THPVariable_subclass_clear at
   // ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
   // first): frame #0: c10::Error::Error(c10::SourceLocation,
   // std::__1::basic_string<char, std::__1::char_traits<char>,
@@ -475,7 +475,7 @@ static int THPVariable_clear(THPVariable* self) {
   // c10::detail::torchInternalAssertFail(char const*, char const*,
   // unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
   // (0x1141e3f89 in libtorch_python.dylib) frame #3:
-  // THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
+  // THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
   // libtorch_python.dylib) frame #4:
   // THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
   // libtorch_python.dylib) frame #5: (anonymous
@@ -507,9 +507,15 @@ static int THPVariable_clear(THPVariable* self) {
   return 0;
 }
 
-int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
+int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
   TORCH_INTERNAL_ASSERT(
-      false, "Tensor tp_traverse function was not overriden properly");
+      false, "TensorBase tp_traverse function was not overridden properly");
   return 0;
 }
+
+int THPFake_clear(THPVariable* self) {
+  TORCH_INTERNAL_ASSERT(
+      false, "TensorBase tp_clear function was not overridden properly");
+  return 0;
+}
 
@@ -1850,8 +1856,8 @@ PyTypeObject THPVariableType = {
     Py_TPFLAGS_HAVE_GC, /* tp_flags */
     nullptr, /* tp_doc */
     // Also set by metaclass
-    (traverseproc)THPFunction_traverse, /* tp_traverse */
-    (inquiry)THPVariable_clear, /* tp_clear */
+    (traverseproc)THPFake_traverse, /* tp_traverse */
+    (inquiry)THPFake_clear, /* tp_clear */
     nullptr, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
@@ -1984,7 +1990,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
   TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
 
   // Finally clear out the base THPVariable
-  THPVariable_clear((THPVariable*)self);
+  THPVariable_subclass_clear((THPVariable*)self);
   ((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
   Py_TYPE(self)->tp_free(self);
@@ -2277,9 +2283,17 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
   if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
     return -1;
   }
+  // It is important for all three of these to be overridden correctly for
+  // the resurrection checks to happen properly. In particular, an older
+  // version was not overriding tp_clear here. This led to the default
+  // subtype_clear running on the Tensor object (as only TensorBase's
+  // tp_clear was custom), clearing the __dict__ field before the TensorBase
+  // custom clear was called and could properly detect the resurrection.
+  // See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
   ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
   ((PyTypeObject*)cls)->tp_traverse =
       (traverseproc)THPVariable_subclass_traverse;
+  ((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
 
   // Don't do anything for the base Tensor class
   if (!THPVariableClass) {