mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add pybind11 exception translator (#30588)
Summary: Closes https://github.com/pytorch/pytorch/issues/30027 The idea here is that you can bind a function with `pybind11` in a single line and without modifying the function: ```cpp m.def("foo", foo, py::call_guard<torch::PyWarningHandler>()); ``` Where warnings are handled by the [`call_guard`](https://pybind11.readthedocs.io/en/stable/advanced/functions.html#call-guard) and exceptions are handled by the `pybind11` exception translator. To do this, I have added support for handling C++ exceptions in `torch::PyWarningHandler`'s destructor without setting the python error state before hand. Pull Request resolved: https://github.com/pytorch/pytorch/pull/30588 Differential Revision: D19905626 Pulled By: albanD fbshipit-source-id: 90c0a5e298b123cc0c8ab9c52c91be4e96ea47c6
This commit is contained in:
committed by
Facebook Github Bot
parent
4c8064c9e1
commit
44af8ee6cd
@ -721,7 +721,6 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
cpp_tensor_name = r"CPUDoubleType"
|
||||
|
||||
# Without error handling, the warnings cannot be catched
|
||||
# and the Tensor type names are not cleaned
|
||||
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
|
||||
cpp_sources=[source],
|
||||
functions=['foo'],
|
||||
@ -731,12 +730,11 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
warn_mod.foo(t, 0)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
# pybind translate all our errors to RuntimeError
|
||||
with self.assertRaisesRegex(RuntimeError, cpp_tensor_name):
|
||||
with self.assertRaisesRegex(TypeError, t.type()):
|
||||
warn_mod.foo(t, 1)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "bad argument to internal function|python_error"):
|
||||
with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
|
||||
warn_mod.foo(t, 2)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cstdarg>
|
||||
#include <exception>
|
||||
|
||||
#include <torch/csrc/THP.h>
|
||||
|
||||
@ -144,7 +145,8 @@ void PyWarningHandler::process(
|
||||
};
|
||||
|
||||
PyWarningHandler::PyWarningHandler() noexcept(true):
|
||||
prev_handler_(c10::Warning::get_warning_handler()) {
|
||||
prev_handler_(c10::Warning::get_warning_handler()),
|
||||
in_exception_(false) {
|
||||
c10::Warning::set_warning_handler(this);
|
||||
}
|
||||
|
||||
@ -154,13 +156,8 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
|
||||
c10::Warning::set_warning_handler(prev_handler_);
|
||||
|
||||
if(warning_buffer_.size() > 0) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
||||
PyObject *ptype, *pvalue, *ptraceback;
|
||||
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
|
||||
|
||||
if(ptype) {
|
||||
// A python error happened after the warning
|
||||
if(in_exception_) {
|
||||
// An error happened after the warning
|
||||
// Simply handle with the previous handler
|
||||
for(const auto& warning: warning_buffer_) {
|
||||
auto source_location = warning.first;
|
||||
@ -168,11 +165,8 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
|
||||
c10::Warning::warn(source_location, msg);
|
||||
}
|
||||
warning_buffer_.clear();
|
||||
// The parent function already returns an error
|
||||
// We only restore the error and exit the
|
||||
// destructor normally
|
||||
PyErr_Restore(ptype, pvalue, ptraceback);
|
||||
} else {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
auto result = 0;
|
||||
for(const auto& warning: warning_buffer_) {
|
||||
auto source_location = warning.first;
|
||||
@ -205,4 +199,3 @@ PyWarningHandler::~PyWarningHandler() noexcept(false) {
|
||||
|
||||
|
||||
} // namespace torch
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <torch/csrc/jit/script/jit_exception.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <ATen/detail/FunctionTraits.h>
|
||||
|
||||
/// NOTE [ Conversion Cpp Python Warning ]
|
||||
/// The warning handler cannot set python warnings immediately
|
||||
@ -41,10 +42,12 @@
|
||||
#define HANDLE_TH_ERRORS \
|
||||
try { \
|
||||
torch::PyWarningHandler __enforce_warning_buffer; \
|
||||
try{
|
||||
try {
|
||||
|
||||
// Only catch torch-specific exceptions
|
||||
#define CATCH_TH_ERRORS(retstmnt) \
|
||||
catch (python_error & e) { \
|
||||
e.restore(); \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (const c10::IndexError& e) { \
|
||||
@ -61,7 +64,10 @@
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(e.python_type(), msg.c_str()); \
|
||||
retstmnt; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CATCH_ALL_ERRORS(retstmnt) \
|
||||
CATCH_TH_ERRORS(retstmnt) \
|
||||
catch (const std::exception& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
|
||||
@ -70,46 +76,30 @@
|
||||
|
||||
#define END_HANDLE_TH_ERRORS_PYBIND \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
/* Unpack already stored error to be detectable by warning code */ \
|
||||
e.restore(); \
|
||||
catch(...) { \
|
||||
__enforce_warning_buffer.set_in_exception(); \
|
||||
throw; \
|
||||
} \
|
||||
catch (py::builtin_exception & e) { \
|
||||
/* Unpack already stored error to be detectable by warning code */ \
|
||||
e.set_error(); \
|
||||
throw; \
|
||||
} \
|
||||
catch (torch::jit::JITException & e) { \
|
||||
/* Special case for JITException that are explicitly unpacked by */\
|
||||
/* pybind. Set a temporary python error to be detectable by */ \
|
||||
/* warning code */ \
|
||||
PyErr_SetString(PyExc_RuntimeError, "JITException"); \
|
||||
throw; \
|
||||
} \
|
||||
CATCH_TH_ERRORS(throw) \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
/* Repack already stored error */ \
|
||||
throw py::error_already_set(); \
|
||||
} \
|
||||
catch (py::builtin_exception & e) { \
|
||||
/* Repack already stored error */ \
|
||||
throw py::error_already_set(); \
|
||||
} \
|
||||
catch (torch::jit::JITException & e) { \
|
||||
/* Special case for JITException that are explicitly unpacked by */ \
|
||||
/* pybind. Clear the temporary error message we used */ \
|
||||
PyErr_Clear(); \
|
||||
throw; \
|
||||
} \
|
||||
CATCH_TH_ERRORS(throw py::error_already_set())
|
||||
catch (py::builtin_exception & e) { \
|
||||
throw; \
|
||||
} \
|
||||
catch (torch::jit::JITException & e) { \
|
||||
throw; \
|
||||
} \
|
||||
CATCH_ALL_ERRORS(throw py::error_already_set())
|
||||
|
||||
#define END_HANDLE_TH_ERRORS_RET(retval) \
|
||||
} \
|
||||
CATCH_TH_ERRORS(return retval) \
|
||||
catch(...) { \
|
||||
__enforce_warning_buffer.set_in_exception(); \
|
||||
throw; \
|
||||
} \
|
||||
} \
|
||||
CATCH_TH_ERRORS(return retval)
|
||||
CATCH_ALL_ERRORS(return retval)
|
||||
|
||||
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
|
||||
|
||||
@ -271,6 +261,16 @@ public:
|
||||
void process(const at::SourceLocation &source_location,
|
||||
const std::string &msg) override;
|
||||
|
||||
/** Call if an exception has been thrown
|
||||
|
||||
* Necessary to determine if it is safe to throw from the desctructor since
|
||||
* std::uncaught_exception is buggy on some platforms and generally
|
||||
* unreliable across dynamic library calls.
|
||||
*/
|
||||
void set_in_exception() {
|
||||
in_exception_ = true;
|
||||
}
|
||||
|
||||
private:
|
||||
using warning_buffer_t =
|
||||
std::vector<std::pair<c10::SourceLocation, std::string>>;
|
||||
@ -278,6 +278,34 @@ private:
|
||||
warning_buffer_t warning_buffer_;
|
||||
|
||||
at::WarningHandler* prev_handler_;
|
||||
bool in_exception_;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <typename Func, size_t i>
|
||||
using Arg = typename function_traits<Func>::template arg<i>::type;
|
||||
|
||||
template <typename Func, size_t ...Is>
|
||||
auto wrap_pybind_function_impl_(Func&& f, std::index_sequence<Is...>) {
|
||||
using traits = function_traits<Func>;
|
||||
namespace py = pybind11;
|
||||
|
||||
// f=f is needed to handle function references on older compilers
|
||||
return [f=f](Arg<Func, Is> ...args) -> typename traits::result_type {
|
||||
HANDLE_TH_ERRORS
|
||||
return f(std::forward<Arg<Func, Is>>(args)...);
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
};
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Wrap a function with TH error and warning handling.
|
||||
// Returns a function object suitable for registering with pybind11.
|
||||
template <typename Func>
|
||||
auto wrap_pybind_function(Func&& f) {
|
||||
using traits = function_traits<Func>;
|
||||
return torch::detail::wrap_pybind_function_impl_(
|
||||
std::forward<Func>(f), std::make_index_sequence<traits::arity>{});
|
||||
}
|
||||
|
||||
} // namespace torch
|
||||
|
@ -742,6 +742,16 @@ PyObject* initModule() {
|
||||
// setting up TH Errors so that they throw C++ exceptions
|
||||
at::init();
|
||||
|
||||
// Automatically translate errors thrown from pybind11 functions
|
||||
py::register_exception_translator([](std::exception_ptr e) { // NOLINT
|
||||
try {
|
||||
if (e) {
|
||||
std::rethrow_exception(e);
|
||||
}
|
||||
}
|
||||
CATCH_TH_ERRORS()
|
||||
});
|
||||
|
||||
auto py_module = py::reinterpret_borrow<py::module>(module);
|
||||
py_module.def("_demangle", &c10::demangle);
|
||||
py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);
|
||||
|
@ -111,7 +111,12 @@ PyObject* dist_autograd_init(PyObject* /* unused */) {
|
||||
for (const auto& root : roots) {
|
||||
variables.emplace_back(root);
|
||||
}
|
||||
DistEngine::getInstance().execute(variables, retainGraph);
|
||||
try {
|
||||
DistEngine::getInstance().execute(variables, retainGraph);
|
||||
} catch (python_error & e) {
|
||||
// FIXME: crashes if exception type is not RuntimeError
|
||||
throw std::runtime_error(e.what());
|
||||
}
|
||||
},
|
||||
R"(
|
||||
backward(roots: List[Tensor], retain_graph = False) -> None
|
||||
|
@ -101,26 +101,6 @@ COMMON_NVCC_FLAGS = [
|
||||
'--expt-relaxed-constexpr'
|
||||
]
|
||||
|
||||
# See comment in load_inline for more information
|
||||
# The goal is to be able to call the safe version of the
|
||||
# function exactly as if it was the original one.
|
||||
# We need to create a pointer to this new function to give
|
||||
# it to pybind later.
|
||||
|
||||
SAFE_FUNCTION_DEFINITION = '''
|
||||
#include <functional>
|
||||
|
||||
template <typename Ret, typename ...Args>
|
||||
auto _get_safe_version(Ret (*f)(Args...)) -> std::function<Ret(Args...)> {{
|
||||
return [f](Args&& ...args) -> Ret {{
|
||||
HANDLE_TH_ERRORS
|
||||
return f(std::forward<Args>(args)...);
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}};
|
||||
}}
|
||||
|
||||
'''
|
||||
|
||||
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
||||
|
||||
|
||||
@ -945,11 +925,6 @@ def load_inline(name,
|
||||
|
||||
cpp_sources.insert(0, '#include <torch/extension.h>')
|
||||
|
||||
# Adds a new `_get_safe_version(foo)` function that returns a new function
|
||||
# that performs the same operation as `foo` but with pytorch error handling
|
||||
# macros.
|
||||
cpp_sources.append(SAFE_FUNCTION_DEFINITION)
|
||||
|
||||
# If `functions` is supplied, we create the pybind11 bindings for the user.
|
||||
# Here, `functions` is (or becomes, after some processing) a map from
|
||||
# function names to function docstrings.
|
||||
@ -967,8 +942,9 @@ def load_inline(name,
|
||||
type(functions)))
|
||||
for function_name, docstring in functions.items():
|
||||
if with_pytorch_error_handling:
|
||||
module_def.append('m.def("{0}", _get_safe_version({0}), "{1}");'.format(
|
||||
function_name, docstring))
|
||||
module_def.append(
|
||||
'm.def("{0}", torch::wrap_pybind_function({0}), "{1}");'
|
||||
.format(function_name, docstring))
|
||||
else:
|
||||
module_def.append('m.def("{0}", {0}, "{1}");'.format(function_name, docstring))
|
||||
module_def.append('}')
|
||||
|
Reference in New Issue
Block a user