Add pybind11 exception translator (#30588)

Summary:
Closes https://github.com/pytorch/pytorch/issues/30027

The idea here is that you can bind a function with `pybind11` in a single line and without modifying the function:
```cpp
m.def("foo", foo, py::call_guard<torch::PyWarningHandler>());
```
Where warnings are handled by the [`call_guard`](https://pybind11.readthedocs.io/en/stable/advanced/functions.html#call-guard) and exceptions are handled by the `pybind11` exception translator. To do this, I have added support for handling C++ exceptions in `torch::PyWarningHandler`'s destructor without setting the python error state before hand.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30588

Differential Revision: D19905626

Pulled By: albanD

fbshipit-source-id: 90c0a5e298b123cc0c8ab9c52c91be4e96ea47c6
This commit is contained in:
Peter Bell
2020-02-18 11:28:35 -08:00
committed by Facebook Github Bot
parent 4c8064c9e1
commit 44af8ee6cd
6 changed files with 87 additions and 77 deletions

View File

@ -721,7 +721,6 @@ class TestCppExtensionJIT(common.TestCase):
cpp_tensor_name = r"CPUDoubleType"
# Without error handling, the warnings cannot be catched
# and the Tensor type names are not cleaned
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
cpp_sources=[source],
functions=['foo'],
@ -731,12 +730,11 @@ class TestCppExtensionJIT(common.TestCase):
warn_mod.foo(t, 0)
self.assertEqual(len(w), 0)
# pybind translate all our errors to RuntimeError
with self.assertRaisesRegex(RuntimeError, cpp_tensor_name):
with self.assertRaisesRegex(TypeError, t.type()):
warn_mod.foo(t, 1)
self.assertEqual(len(w), 0)
with self.assertRaisesRegex(RuntimeError, "bad argument to internal function|python_error"):
with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
warn_mod.foo(t, 2)
self.assertEqual(len(w), 0)

View File

@ -4,6 +4,7 @@
#include <utility>
#include <vector>
#include <cstdarg>
#include <exception>
#include <torch/csrc/THP.h>
@ -144,7 +145,8 @@ void PyWarningHandler::process(
};
PyWarningHandler::PyWarningHandler() noexcept(true):
prev_handler_(c10::Warning::get_warning_handler()) {
prev_handler_(c10::Warning::get_warning_handler()),
in_exception_(false) {
c10::Warning::set_warning_handler(this);
}
@ -154,13 +156,8 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
c10::Warning::set_warning_handler(prev_handler_);
if(warning_buffer_.size() > 0) {
pybind11::gil_scoped_acquire gil;
PyObject *ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
if(ptype) {
// A python error happened after the warning
if(in_exception_) {
// An error happened after the warning
// Simply handle with the previous handler
for(const auto& warning: warning_buffer_) {
auto source_location = warning.first;
@ -168,11 +165,8 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
c10::Warning::warn(source_location, msg);
}
warning_buffer_.clear();
// The parent function already returns an error
// We only restore the error and exit the
// destructor normally
PyErr_Restore(ptype, pvalue, ptraceback);
} else {
pybind11::gil_scoped_acquire gil;
auto result = 0;
for(const auto& warning: warning_buffer_) {
auto source_location = warning.first;
@ -205,4 +199,3 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
} // namespace torch

View File

@ -13,6 +13,7 @@
#include <torch/csrc/jit/script/jit_exception.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <c10/util/StringUtil.h>
#include <ATen/detail/FunctionTraits.h>
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
@ -41,10 +42,12 @@
#define HANDLE_TH_ERRORS \
try { \
torch::PyWarningHandler __enforce_warning_buffer; \
try{
try {
// Only catch torch-specific exceptions
#define CATCH_TH_ERRORS(retstmnt) \
catch (python_error & e) { \
e.restore(); \
retstmnt; \
} \
catch (const c10::IndexError& e) { \
@ -61,7 +64,10 @@
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg.c_str()); \
retstmnt; \
} \
}
#define CATCH_ALL_ERRORS(retstmnt) \
CATCH_TH_ERRORS(retstmnt) \
catch (const std::exception& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
@ -70,46 +76,30 @@
#define END_HANDLE_TH_ERRORS_PYBIND \
} \
catch (py::error_already_set & e) { \
/* Unpack already stored error to be detectable by warning code */ \
e.restore(); \
catch(...) { \
__enforce_warning_buffer.set_in_exception(); \
throw; \
} \
catch (py::builtin_exception & e) { \
/* Unpack already stored error to be detectable by warning code */ \
e.set_error(); \
throw; \
} \
catch (torch::jit::JITException & e) { \
/* Special case for JITException that are explicitly unpacked by */\
/* pybind. Set a temporary python error to be detectable by */ \
/* warning code */ \
PyErr_SetString(PyExc_RuntimeError, "JITException"); \
throw; \
} \
CATCH_TH_ERRORS(throw) \
} \
catch (py::error_already_set & e) { \
/* Repack already stored error */ \
throw py::error_already_set(); \
} \
catch (py::builtin_exception & e) { \
/* Repack already stored error */ \
throw py::error_already_set(); \
} \
catch (torch::jit::JITException & e) { \
/* Special case for JITException that are explicitly unpacked by */ \
/* pybind. Clear the temporary error message we used */ \
PyErr_Clear(); \
throw; \
} \
CATCH_TH_ERRORS(throw py::error_already_set())
catch (py::builtin_exception & e) { \
throw; \
} \
catch (torch::jit::JITException & e) { \
throw; \
} \
CATCH_ALL_ERRORS(throw py::error_already_set())
#define END_HANDLE_TH_ERRORS_RET(retval) \
} \
CATCH_TH_ERRORS(return retval) \
catch(...) { \
__enforce_warning_buffer.set_in_exception(); \
throw; \
} \
} \
CATCH_TH_ERRORS(return retval)
CATCH_ALL_ERRORS(return retval)
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
@ -271,6 +261,16 @@ public:
void process(const at::SourceLocation &source_location,
const std::string &msg) override;
/** Call if an exception has been thrown
* Necessary to determine if it is safe to throw from the desctructor since
* std::uncaught_exception is buggy on some platforms and generally
* unreliable across dynamic library calls.
*/
void set_in_exception() {
in_exception_ = true;
}
private:
using warning_buffer_t =
std::vector<std::pair<c10::SourceLocation, std::string>>;
@ -278,6 +278,34 @@ private:
warning_buffer_t warning_buffer_;
at::WarningHandler* prev_handler_;
bool in_exception_;
};
namespace detail {
template <typename Func, size_t i>
using Arg = typename function_traits<Func>::template arg<i>::type;
template <typename Func, size_t ...Is>
auto wrap_pybind_function_impl_(Func&& f, std::index_sequence<Is...>) {
using traits = function_traits<Func>;
namespace py = pybind11;
// f=f is needed to handle function references on older compilers
return [f=f](Arg<Func, Is> ...args) -> typename traits::result_type {
HANDLE_TH_ERRORS
return f(std::forward<Arg<Func, Is>>(args)...);
END_HANDLE_TH_ERRORS_PYBIND
};
}
} // namespace detail
// Wrap a function with TH error and warning handling.
// Returns a function object suitable for registering with pybind11.
template <typename Func>
auto wrap_pybind_function(Func&& f) {
using traits = function_traits<Func>;
return torch::detail::wrap_pybind_function_impl_(
std::forward<Func>(f), std::make_index_sequence<traits::arity>{});
}
} // namespace torch

View File

@ -742,6 +742,16 @@ PyObject* initModule() {
// setting up TH Errors so that they throw C++ exceptions
at::init();
// Automatically translate errors thrown from pybind11 functions
py::register_exception_translator([](std::exception_ptr e) { // NOLINT
try {
if (e) {
std::rethrow_exception(e);
}
}
CATCH_TH_ERRORS()
});
auto py_module = py::reinterpret_borrow<py::module>(module);
py_module.def("_demangle", &c10::demangle);
py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);

View File

@ -111,7 +111,12 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
for (const auto& root : roots) {
variables.emplace_back(root);
}
DistEngine::getInstance().execute(variables, retainGraph);
try {
DistEngine::getInstance().execute(variables, retainGraph);
} catch (python_error & e) {
// FIXME: crashes if exception type is not RuntimeError
throw std::runtime_error(e.what());
}
},
R"(
backward(roots: List[Tensor], retain_graph = False) -> None

View File

@ -101,26 +101,6 @@ COMMON_NVCC_FLAGS = [
'--expt-relaxed-constexpr'
]
# See comment in load_inline for more information
# The goal is to be able to call the safe version of the
# function exactly as if it was the original one.
# We need to create a pointer to this new function to give
# it to pybind later.
SAFE_FUNCTION_DEFINITION = '''
#include <functional>
template <typename Ret, typename ...Args>
auto _get_safe_version(Ret (*f)(Args...)) -> std::function<Ret(Args...)> {{
return [f](Args&& ...args) -> Ret {{
HANDLE_TH_ERRORS
return f(std::forward<Args>(args)...);
END_HANDLE_TH_ERRORS_PYBIND
}};
}}
'''
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@ -945,11 +925,6 @@ def load_inline(name,
cpp_sources.insert(0, '#include <torch/extension.h>')
# Adds a new `_get_safe_version(foo)` function that returns a new function
# that performs the same operation as `foo` but with pytorch error handling
# macros.
cpp_sources.append(SAFE_FUNCTION_DEFINITION)
# If `functions` is supplied, we create the pybind11 bindings for the user.
# Here, `functions` is (or becomes, after some processing) a map from
# function names to function docstrings.
@ -967,8 +942,9 @@ def load_inline(name,
type(functions)))
for function_name, docstring in functions.items():
if with_pytorch_error_handling:
module_def.append('m.def("{0}", _get_safe_version({0}), "{1}");'.format(
function_name, docstring))
module_def.append(
'm.def("{0}", torch::wrap_pybind_function({0}), "{1}");'
.format(function_name, docstring))
else:
module_def.append('m.def("{0}", {0}, "{1}");'.format(function_name, docstring))
module_def.append('}')