mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix segfault in backward
This commit is contained in:
@ -103,8 +103,11 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
|
||||
Py_VISIT(self->data);
|
||||
Py_VISIT(self->backward_hooks);
|
||||
if (self->cdata.defined()) {
|
||||
if (auto fn = dynamic_cast<PyFunction*>(self->cdata.grad_fn().get())) {
|
||||
Py_VISIT(fn->obj);
|
||||
// Only visit this if we actually own it (no one else use the shared pointer)
|
||||
if (self->cdata.grad_fn().use_count() == 1) {
|
||||
if (auto fn = dynamic_cast<PyFunction*>(self->cdata.grad_fn().get())) {
|
||||
Py_VISIT(fn->obj);
|
||||
}
|
||||
}
|
||||
for (auto& hook : self->cdata.hooks()) {
|
||||
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
|
||||
|
Reference in New Issue
Block a user