[autograd] allow PyNode to persist error message (#34845)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34845

This PR allows PyNode to persist the error message so that any pure C++ thread that runs autograd with a custom Python autograd function can successfully capture the error message without maintaining an initial PyThreadState.

Test Plan: Imported from OSS

Differential Revision: D20480685

Pulled By: wanchaol

fbshipit-source-id: 0488ea5a4df9a33b53ac5d0d59000c41ab6cb748
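Below is a minimal, hypothetical Python sketch of the situation this change targets: a custom autograd function whose backward raises. The class name Fail and the error text are invented for illustration, and running the snippet from a normal Python interpreter does not by itself exercise the pure-C++-thread path; it only shows the user-level error that PyNode has to carry back to whatever code drives the backward pass.

import torch

# Hypothetical example: a custom autograd function that fails in backward.
class Fail(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # This message is what the autograd engine must propagate to the caller,
        # even when that caller is a C++ thread with no Python thread state.
        raise RuntimeError("error raised inside a custom backward")

x = torch.randn(3, requires_grad=True)
loss = Fail.apply(x).sum()
try:
    loss.backward()
except Exception as err:
    print(err)  # the original error text is preserved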
Committed by: Facebook GitHub Bot
Parent: 8346959f38
Commit: 9e7821ee82
@@ -44,11 +44,16 @@ PyObject *THPFunctionClass = nullptr;
 
 namespace torch { namespace autograd {
 
+void PyNode::throw_python_error() {
+  python_error err;
+  err.persist();
+  throw err;
+}
 auto PyNode::legacy_apply(const variable_list& inputs) -> variable_list {
   pybind11::gil_scoped_acquire gil;
 
   THPObjectPtr pyInputs(PyTuple_New(inputs.size()));
-  if (!pyInputs) throw python_error();
+  if (!pyInputs) throw_python_error();
 
   for (size_t i = 0; i != inputs.size(); ++i) {
     PyTuple_SET_ITEM(pyInputs.get(), i, THPVariable_Wrap(inputs[i]));
@@ -56,7 +61,7 @@ auto PyNode::legacy_apply(const variable_list& inputs) -> variable_list {
 
   THPObjectPtr r(PyObject_CallMethod(
       obj, "_do_backward", "OO", pyInputs.get(), Py_True));
-  if (!r) throw python_error();
+  if (!r) throw_python_error();
 
   auto num_outputs = PyTuple_GET_SIZE(r.get());
   tensor_list tensor_results(num_outputs);
@@ -104,7 +109,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
   // Massage a C++ variable_list into a Python arguments tuple
   auto num_inputs = inputs.size();
   THPObjectPtr pyInputs(PyTuple_New(num_inputs));
-  if (!pyInputs) throw python_error();
+  if (!pyInputs) throw_python_error();
   auto& output_info = py_fn->output_info;
   for (size_t i = 0; i < num_inputs; ++i) {
     PyObject* input;
@@ -113,14 +118,14 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
     } else {
       input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
     }
-    if (!input) throw python_error();
+    if (!input) throw_python_error();
     PyTuple_SET_ITEM(pyInputs.get(), i, input);
   }
 
   THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
-  if (!apply_fn) throw python_error();
+  if (!apply_fn) throw_python_error();
   THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
-  if (!r) throw python_error();
+  if (!r) throw_python_error();
   ensure_tuple(r);
 
   auto& is_variable_input = py_fn->is_variable_input;
@@ -136,7 +141,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
     if (all_none) {
       num_outputs = num_forward_inputs;
       r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
-      if (!r) throw python_error();
+      if (!r) throw_python_error();
     }
   }
 
@@ -189,9 +194,9 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list {
 auto PyNode::is_traceable() -> bool {
   pybind11::gil_scoped_acquire gil;
   THPObjectPtr forward_class {PyObject_GetAttrString(obj, "_forward_cls")};
-  if (!forward_class) throw python_error();
+  if (!forward_class) throw_python_error();
   THPObjectPtr traceable_py_bool {PyObject_GetAttrString(forward_class, "is_traceable")};
-  if (!traceable_py_bool) throw python_error();
+  if (!traceable_py_bool) throw_python_error();
   return traceable_py_bool == Py_True;
 }
 