Fix segfault in backward

This commit is contained in:
albanD
2017-09-13 13:10:23 +01:00
committed by Soumith Chintala
parent d910a94b2b
commit 2356ee41b7
2 changed files with 23 additions and 2 deletions

View File

@ -103,8 +103,11 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
Py_VISIT(self->data);
Py_VISIT(self->backward_hooks);
if (self->cdata.defined()) {
if (auto fn = dynamic_cast<PyFunction*>(self->cdata.grad_fn().get())) {
Py_VISIT(fn->obj);
// Only visit this if we actually own it (no one else use the shared pointer)
if (self->cdata.grad_fn().use_count() == 1) {
if (auto fn = dynamic_cast<PyFunction*>(self->cdata.grad_fn().get())) {
Py_VISIT(fn->obj);
}
}
for (auto& hook : self->cdata.hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {