mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Follows #133180 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133295 Approved by: https://github.com/Skylion007
128 lines
3.8 KiB
C++
128 lines
3.8 KiB
C++
#include <c10/util/Exception.h>
|
|
#include <pybind11/pybind11.h>
|
|
#include <torch/csrc/Exceptions.h>
|
|
#include <torch/csrc/autograd/python_anomaly_mode.h>
|
|
#include <torch/csrc/autograd/python_cpp_function.h>
|
|
#include <torch/csrc/python_headers.h>
|
|
#include <torch/csrc/utils/object_ptr.h>
|
|
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/python_strings.h>
|
|
|
|
namespace torch::autograd {
|
|
|
|
void PyAnomalyMetadata::store_stack() {
|
|
pybind11::gil_scoped_acquire gil;
|
|
THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
|
|
if (!mod) {
|
|
throw python_error();
|
|
}
|
|
|
|
THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
|
|
if (!list) {
|
|
throw python_error();
|
|
}
|
|
|
|
if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
|
|
throw python_error();
|
|
}
|
|
}
|
|
|
|
void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
|
|
pybind11::gil_scoped_acquire gil;
|
|
if (!PyDict_Check(dict())) {
|
|
throw std::runtime_error("Anomaly metadata is not a python dictionary.");
|
|
}
|
|
PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
|
|
_print_stack(trace_stack, current_node_name, false);
|
|
PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));
|
|
|
|
// if there is no "parent_" in metadata, then it means this metadata's node
|
|
// is the root and stop printing the traceback
|
|
while (pyparent) {
|
|
THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
|
|
if (!parent_metadata) {
|
|
throw python_error();
|
|
}
|
|
THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
|
|
if (!parent_name_pyobj) {
|
|
throw python_error();
|
|
}
|
|
const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
|
|
if (!parent_name_char) {
|
|
throw python_error();
|
|
}
|
|
const std::string parent_name(parent_name_char);
|
|
PyObject* parent_stack =
|
|
PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
|
|
_print_stack(parent_stack, parent_name, true);
|
|
// get the parent of this node, if this node is a root, pyparent is simply
|
|
// null
|
|
pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
|
|
}
|
|
}
|
|
|
|
void PyAnomalyMetadata::assign_parent(
|
|
const std::shared_ptr<Node>& parent_node) {
|
|
// assign the python object of parent_node in metadata["parent_"]
|
|
// if parent_node is nullptr, then do nothing (it can mean that "parent_" key
|
|
// is not in metadata)
|
|
|
|
pybind11::gil_scoped_acquire gil;
|
|
if (!parent_node)
|
|
return;
|
|
|
|
THPObjectPtr parent_node_(functionToPyObject(parent_node));
|
|
if (!parent_node_) {
|
|
throw python_error();
|
|
}
|
|
if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
|
|
throw python_error();
|
|
}
|
|
}
|
|
|
|
// Emits a warning containing the forward-pass traceback held in `stack`.
//
// stack: Python sequence of strings making up the traceback, or null when no
//   traceback was recorded.
// current_node_name: node name interpolated into the warning text.
// is_parent: false for the node where the error was detected; true for an
//   ancestor whose calculation induced it. Selects the warning wording.
//
// Throws python_error if building the separator string or joining the stack
// entries fails.
void _print_stack(
    PyObject* stack,
    const std::string& current_node_name,
    bool is_parent) {
  // Without a recorded stack, all we can do is point the user at anomaly
  // detection for the forward pass.
  if (stack == nullptr) {
    TORCH_WARN(
        "Error detected in ",
        current_node_name,
        ". ",
        "No forward pass information available. Enable detect anomaly "
        "during forward pass for more information.");
    return;
  }

  THPObjectPtr separator(PyUnicode_FromString(""));
  if (separator.get() == nullptr) {
    throw python_error();
  }

  // The entries are Python strings that already end in newlines, so joining
  // them with an empty separator yields the full traceback text.
  THPObjectPtr joined(PyUnicode_Join(separator.get(), stack));
  if (joined.get() == nullptr) {
    throw python_error();
  }

  if (is_parent) {
    TORCH_WARN(
        "\n\n",
        "Previous calculation was induced by ",
        current_node_name,
        ". "
        "Traceback of forward call that induced the previous calculation:\n",
        THPUtils_unpackString(joined.get()));
  } else {
    TORCH_WARN(
        "Error detected in ",
        current_node_name,
        ". ",
        "Traceback of forward call that caused the error:\n",
        THPUtils_unpackString(joined.get()));
  }
}
|
|
|
|
} // namespace torch::autograd
|