mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix resurrection logic to trigger early enough (#137267)
Fixes https://github.com/pytorch/pytorch/issues/136358

The bug here is that the Tensor object is actually two classes: `Tensor` from `_tensor.py` and `TensorBase` from C++. Before this PR, they had the following GC methods:

Tensor:
- tp_clear: subtype_clear
- tp_traverse: THPVariable_subclass_traverse
- tp_dealloc: THPVariable_subclass_dealloc

TensorBase:
- tp_clear: THPVariable_clear
- tp_traverse: THPFunction_traverse (fake function that is just an error)
- tp_dealloc: object_dealloc

The problem is that when clear is called on the Tensor, subtype_clear clears the things owned by the "Tensor" type, in particular its `__dict__` attribute, before delegating to the TensorBase clear, where we detect that resurrection needs to happen and skip the clear.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137267
Approved by: https://github.com/ezyang, https://github.com/kshitij12345
This commit is contained in:
@ -10036,6 +10036,30 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
self.assertEqual(MyStorage.finalized_count, 1)
|
||||
self.assertTrue(m[0])
|
||||
|
||||
def test_tensor_ressurecting_clear(self):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/136358
|
||||
# A Tensor with custom __dict__
|
||||
# Autograd here is for the c++ reference later
|
||||
t = torch.rand(2, requires_grad=True).clone()
|
||||
t.foo = 2
|
||||
|
||||
# that is part of a cycle
|
||||
l = []
|
||||
l.append(l)
|
||||
l.append(t)
|
||||
|
||||
# Keep the Tensor alive from c++
|
||||
# Using the autograd graph here (any other means would work)
|
||||
t2 = t ** 2
|
||||
self.assertIs(t2.grad_fn._saved_self, t)
|
||||
|
||||
# Clear all python references and trigger the gc
|
||||
del t, l
|
||||
gc.collect()
|
||||
|
||||
# We used to lose the dict!
|
||||
self.assertTrue(hasattr(t2.grad_fn._saved_self, "foo"))
|
||||
|
||||
def test_tensor_slot_dealloc(self):
|
||||
|
||||
class SlotTensor1(torch.Tensor):
|
||||
|
@ -409,7 +409,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static int THPVariable_clear(THPVariable* self) {
|
||||
static int THPVariable_subclass_clear(THPVariable* self) {
|
||||
// Is it OK for an object to still be live after running
|
||||
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
|
||||
// that an object will dealloc after it's cleared. The source code explicitly
|
||||
@ -465,7 +465,7 @@ static int THPVariable_clear(THPVariable* self) {
|
||||
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
|
||||
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
|
||||
// please report a bug to PyTorch. Exception raised from
|
||||
// THPVariable_clear at
|
||||
// THPVariable_subclass_clear at
|
||||
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
|
||||
// first): frame #0: c10::Error::Error(c10::SourceLocation,
|
||||
// std::__1::basic_string<char, std::__1::char_traits<char>,
|
||||
@ -475,7 +475,7 @@ static int THPVariable_clear(THPVariable* self) {
|
||||
// c10::detail::torchInternalAssertFail(char const*, char const*,
|
||||
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
|
||||
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
|
||||
// THPVariable_clear(THPVariable*) + 412 (0x1148a547c in
|
||||
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
|
||||
// libtorch_python.dylib) frame #4:
|
||||
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
|
||||
// libtorch_python.dylib) frame #5: (anonymous
|
||||
@ -507,9 +507,15 @@ static int THPVariable_clear(THPVariable* self) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
|
||||
int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "Tensor tp_traverse function was not overriden properly");
|
||||
false, "TensorBase tp_traverse function was not overriden properly");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int THPFake_clear(THPVariable* self) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "TensorBase tp_clear function was not overriden properly");
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -1850,8 +1856,8 @@ PyTypeObject THPVariableType = {
|
||||
Py_TPFLAGS_HAVE_GC, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
// Also set by metaclass
|
||||
(traverseproc)THPFunction_traverse, /* tp_traverse */
|
||||
(inquiry)THPVariable_clear, /* tp_clear */
|
||||
(traverseproc)THPFake_traverse, /* tp_traverse */
|
||||
(inquiry)THPFake_clear, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
@ -1984,7 +1990,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
|
||||
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
|
||||
|
||||
// Finally clear out the base THPVariable
|
||||
THPVariable_clear((THPVariable*)self);
|
||||
THPVariable_subclass_clear((THPVariable*)self);
|
||||
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
|
||||
Py_TYPE(self)->tp_free(self);
|
||||
|
||||
@ -2277,9 +2283,17 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
|
||||
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
|
||||
return -1;
|
||||
}
|
||||
// It is important for all three of these to be overridden correctly for the
|
||||
// resurrection checks to properly happen. In particular, an older version
|
||||
// was not overriding tp_clear here. This led to the default subtype_clear
|
||||
// running on the Tensor object (as only TensorBase tp_clear was custom),
|
||||
// clearing the __dict__ field, before the TensorBase custom clear was called
|
||||
// and would properly detect the resurrect.
|
||||
// See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
|
||||
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
|
||||
((PyTypeObject*)cls)->tp_traverse =
|
||||
(traverseproc)THPVariable_subclass_traverse;
|
||||
((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
|
||||
|
||||
// Don't do anything for the base Tensor class
|
||||
if (!THPVariableClass) {
|
||||
|
Reference in New Issue
Block a user