Files
pytorch/torch/csrc/autograd/python_anomaly_mode.cpp
Jeffrey Wan 2e8e560cdf Fix anomaly mode memory leak (#51610)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/51349

The memory leak happens when 1) `create_graph` is True AND 2) detect anomaly mode is on. When a backward node's constructor is called during backward, the currently evaluating node is assigned as the "parent" of the created node. The code that assigns the parent has the following issue:

`functionToPyObject(parent_node)` returns a new PyObject (with refcount 1) or, if the PyObject already exists, increments its refcount by 1. However, [PyDict_SetItem](https://github.com/python/cpython/blob/1b55b65638/Objects/dictobject.c#L1532) calls into [insertdict](https://github.com/python/cpython/blob/v3.8.1/Objects/dictobject.c#L1034), which increments the refcount again. This means that when the dict is destroyed, the refcount of the PyObject is still at least one. This keeps `parent_node` (the backward function) alive, which in turn keeps the saved tensor alive.

Similar calls to `functionToPyObject` elsewhere in the codebase don't require a Py_DECREF when the result is then passed into a tuple (instead of a dict), because the analogous PyTuple_SetItem call steals the reference rather than incrementing the refcount.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51610

Reviewed By: albanD

Differential Revision: D26240336

Pulled By: soulitzer

fbshipit-source-id: 2854528f66fab9dbce448f8a7ba732ce386a7310
2021-02-04 11:53:37 -08:00

114 lines
3.7 KiB
C++

#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/python_strings.h>
#include <iostream>
namespace torch { namespace autograd {
void PyAnomalyMetadata::store_stack() {
pybind11::gil_scoped_acquire gil;
THPObjectPtr mod(PyImport_ImportModule("traceback"));
if (!mod) {
throw python_error();
}
THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
if (!list) {
throw python_error();
}
if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
throw python_error();
}
}
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
pybind11::gil_scoped_acquire gil;
if (!PyDict_Check(dict())) {
throw std::runtime_error("Anomaly metadata is not a python dictionary.");
}
PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
_print_stack(trace_stack, current_node_name, false);
PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));
// if there is no "parent_" in metadata, then it means this metadata's node
// is the root and stop printing the traceback
while (pyparent) {
THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
if (!parent_metadata) {
throw python_error();
}
THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
if (!parent_name_pyobj) {
throw python_error();
}
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
if (!parent_name_char) {
throw python_error();
}
const std::string parent_name(parent_name_char);
PyObject* parent_stack = PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
_print_stack(parent_stack, parent_name, true);
// get the parent of this node, if this node is a root, pyparent is simply null
pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
}
}
void PyAnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  // Store the Python object wrapping `parent_node` in metadata["parent_"]
  // (ANOMALY_PARENT_KEY). print_stack() reads this key to walk up the chain
  // of parents.
  //
  // A null `parent_node` is a no-op, which means the "parent_" key may be
  // absent from the metadata.
  //
  // Throws python_error if the Python wrapper cannot be obtained or the dict
  // insertion fails.

  // Testing the shared_ptr touches no Python state, so bail out before
  // paying for a GIL acquisition in the common "no parent" case.
  if (!parent_node) return;
  pybind11::gil_scoped_acquire gil;
  // functionToPyObject gives us an owned reference (new, or an existing
  // object with its refcount bumped). PyDict_SetItemString takes its own
  // reference, so THPObjectPtr must drop ours when it goes out of scope.
  // Otherwise the parent node (and its saved tensors) leaks.
  THPObjectPtr parent_node_(functionToPyObject(parent_node));
  if (!parent_node_) {
    throw python_error();
  }
  if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
    throw python_error();
  }
}
void _print_stack(PyObject* stack, const std::string& current_node_name, bool is_parent) {
  // Emit, via TORCH_WARN, the forward-pass traceback `stack` recorded for the
  // node named `current_node_name`.
  // - `stack` is expected to be a sequence of newline-terminated Python
  //   strings; nullptr means no traceback was recorded.
  // - `is_parent` selects the wording used for ancestor nodes in the chain.
  // Throws python_error if building or joining the strings fails.
  if (stack == nullptr) {
    TORCH_WARN("Error detected in ", current_node_name, ". ",
               "No forward pass information available. Enable detect anomaly "
               "during forward pass for more information.");
    return;
  }

  THPObjectPtr separator(PyUnicode_FromString(""));
  if (!separator) {
    throw python_error();
  }
  // Every entry already ends with a newline, so join with an empty separator.
  THPObjectPtr joined(PyUnicode_Join(separator.get(), stack));
  if (!joined) {
    throw python_error();
  }
  const std::string traceback_text = THPUtils_unpackString(joined.get());

  if (is_parent) {
    TORCH_WARN("\n\n",
               "Previous calculation was induced by ", current_node_name, ". "
               "Traceback of forward call that induced the previous calculation:\n",
               traceback_text);
    return;
  }
  TORCH_WARN("Error detected in ", current_node_name, ". ",
             "Traceback of forward call that caused the error:\n",
             traceback_text);
}
}}