[autograd] allow PyNode to persist error message (#34845)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34845

This PR allows PyNode to persist the error message so that any pure C++
thread that runs autograd with custom Python autograd function can successfully
catpure the error message without maintaining a initial PyThreadState.

Test Plan: Imported from OSS

Differential Revision: D20480685

Pulled By: wanchaol

fbshipit-source-id: 0488ea5a4df9a33b53ac5d0d59000c41ab6cb748
This commit is contained in:
Wanchao Liang
2020-03-23 21:51:27 -07:00
committed by Facebook GitHub Bot
parent 8346959f38
commit 9e7821ee82
3 changed files with 49 additions and 9 deletions

View File

@ -44,11 +44,16 @@ PyObject *THPFunctionClass = nullptr;
namespace torch { namespace autograd {
void PyNode::throw_python_error() {
python_error err;
err.persist();
throw err;
}
auto PyNode::legacy_apply(const variable_list& inputs) -> variable_list {
pybind11::gil_scoped_acquire gil;
THPObjectPtr pyInputs(PyTuple_New(inputs.size()));
if (!pyInputs) throw python_error();
if (!pyInputs) throw_python_error();
for (size_t i = 0; i != inputs.size(); ++i) {
PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i]));
@ -56,7 +61,7 @@ auto PyNode::legacy_apply(const variable_list& inputs) -> variable_list {
THPObjectPtr r(PyObject_CallMethod(
obj, "_do_backward", "OO", pyInputs.get(), Py_True));
if (!r) throw python_error();
if (!r) throw_python_error();
auto num_outputs = PyTuple_GET_SIZE(r.get());
tensor_list tensor_results(num_outputs);
@ -104,7 +109,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
// Massage a C++ variable_list into a Python arguments tuple
auto num_inputs = inputs.size();
THPObjectPtr pyInputs(PyTuple_New(num_inputs));
if (!pyInputs) throw python_error();
if (!pyInputs) throw_python_error();
auto& output_info = py_fn->output_info;
for (size_t i = 0; i < num_inputs; ++i) {
PyObject* input;
@ -113,14 +118,14 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
} else {
input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
}
if (!input) throw python_error();
if (!input) throw_python_error();
PyTuple_SET_ITEM(pyInputs.get(), i, input);
}
THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
if (!apply_fn) throw python_error();
if (!apply_fn) throw_python_error();
THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
if (!r) throw python_error();
if (!r) throw_python_error();
ensure_tuple(r);
auto& is_variable_input = py_fn->is_variable_input;
@ -136,7 +141,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
if (all_none) {
num_outputs = num_forward_inputs;
r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
if (!r) throw python_error();
if (!r) throw_python_error();
}
}
@ -189,9 +194,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
auto PyNode::is_traceable() -> bool {
pybind11::gil_scoped_acquire gil;
THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")};
if (!forward_class) throw python_error();
if (!forward_class) throw_python_error();
THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")};
if (!traceable_py_bool) throw python_error();
if (!traceable_py_bool) throw_python_error();
return traceable_py_bool == Py_True;
}