add pybind version of HANDLE_TH_ERRORS

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26614

Test Plan: Imported from OSS

Differential Revision: D18249634

Pulled By: albanD

fbshipit-source-id: 25503f368926e0f3633c5af0f222c9bb4729f342
This commit is contained in:
Alban Desmaison
2019-11-07 08:32:51 -08:00
committed by Facebook Github Bot
parent 9b875e1256
commit 0ff1696c75
6 changed files with 216 additions and 23 deletions

View File

@ -727,6 +727,111 @@ class TestCppExtension(common.TestCase):
pattern = r'.*(\\n|\\r).*'
self.assertNotRegex(str(e), pattern)
def test_warning(self):
"""Check that TORCH_WARN and C++/pybind errors surface correctly in Python.

Builds an inline extension twice — once with pybind's default error
handling (with_pytorch_error_handling=False) and once with the PyTorch
HANDLE_TH_ERRORS macros — and verifies which warnings are recorded and
which exception types reach Python in each mode.
"""
# Note: the module created from this source will include the py::key_error
# symbol. But because of visibility and the fact that it lives in a
# different compilation unit than pybind, this trips up ubsan even though
# it is fine. "ubsan.supp" thus needs to contain "vptr:warn_mod.so".
source = '''
// error_type:
// 0: no error
// 1: torch::TypeError
// 2: python_error()
// 3: py::error_already_set
at::Tensor foo(at::Tensor x, int error_type) {
std::ostringstream err_stream;
err_stream << "Error with " << x.type();
TORCH_WARN(err_stream.str());
if(error_type == 1) {
throw torch::TypeError(err_stream.str().c_str());
}
if(error_type == 2) {
PyObject* obj = PyTuple_New(-1);
TORCH_CHECK(!obj);
// Pretend it was caught in a different thread and restored here
auto e = python_error();
e.persist();
e.restore();
throw e;
}
if(error_type == 3) {
throw py::key_error(err_stream.str());
}
return x.cos();
}
'''
# Ensure double type for the hard-coded C++ type name asserted below
t = torch.rand(2).double()
cpp_tensor_name = r"CPUDoubleType"
# Without error handling, the warnings cannot be caught
# and the Tensor type names are not cleaned
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
cpp_sources=[source],
functions=['foo'],
with_pytorch_error_handling=False)
with warnings.catch_warnings(record=True) as w:
warn_mod.foo(t, 0)
self.assertEqual(len(w), 0)
# pybind translates all our errors to RuntimeError
with self.assertRaisesRegex(RuntimeError, cpp_tensor_name):
warn_mod.foo(t, 1)
self.assertEqual(len(w), 0)
with self.assertRaisesRegex(RuntimeError, "bad argument to internal function|python_error"):
warn_mod.foo(t, 2)
self.assertEqual(len(w), 0)
with self.assertRaisesRegex(KeyError, cpp_tensor_name):
warn_mod.foo(t, 3)
self.assertEqual(len(w), 0)
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
cpp_sources=[source],
functions=['foo'],
with_pytorch_error_handling=True)
with warnings.catch_warnings(record=True) as w:
# Warning emitted with no error should be detected
warn_mod.foo(t, 0)
self.assertEqual(len(w), 1)
# Warning emitted alongside a C++ error should not be detected
with self.assertRaisesRegex(TypeError, t.type()):
warn_mod.foo(t, 1)
self.assertEqual(len(w), 1)
# Warning emitted alongside a Python error should not be detected
with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
warn_mod.foo(t, 2)
self.assertEqual(len(w), 1)
# Warning emitted alongside a pybind error should not be detected
# Note that there is no type name translation for pybind errors
with self.assertRaisesRegex(KeyError, cpp_tensor_name):
warn_mod.foo(t, 3)
self.assertEqual(len(w), 1)
# Make sure warnings promoted to errors are handled properly
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("error")
# No error happened, so the warning itself should raise
with self.assertRaisesRegex(UserWarning, t.type()):
warn_mod.foo(t, 0)
self.assertEqual(len(w), 0)
# Another error happened, so the warning is ignored
with self.assertRaisesRegex(TypeError, t.type()):
warn_mod.foo(t, 1)
self.assertEqual(len(w), 0)
class TestMSNPUTensor(common.TestCase):
@classmethod

View File

@ -5680,7 +5680,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(e1, e2)
def test_batch_norm_cpu_inference(self):
# input nchw in (2,1,1,1), (2,2,2,2)
# input nchw in (2,1,1,1), (2,2,2,2)
inputs = [
torch.tensor([[[[-0.5000]]], [[[0.5000]]]]),
torch.tensor([
@ -5692,7 +5692,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[[0.1000, 1.0000], [1.0000, 0.1000]],
[[1.0000, 0.5000], [1.5000, -1.5000]]
]])]
# output nchw in (2,1,1,1), (2,2,2,2)
# output nchw in (2,1,1,1), (2,2,2,2)
outputs = [
torch.tensor([
[[[-0.499997496604919433593750000]]],
@ -5718,9 +5718,9 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
output2 = m(input2).permute(0, 1, 3, 2)
# channels last case
input3 = input1.contiguous(memory_format=torch.channels_last)
for name, param in m.named_parameters():
if param.requires_grad:
if param.data.dim() == 4:
for name, param in m.named_parameters():
if param.requires_grad:
if param.data.dim() == 4:
param.data = param.data.contiguous(memory_format=torch.channels_last)
output3 = m(input3)
self.assertEqual(output3, outputs[i])
@ -14292,7 +14292,6 @@ def generate_not_implemented_tests(cls):
class TestTensorDeviceOps(TestCase):
pass
class TestTorch(TestCase, _TestTorchMixin):
pass

View File

@ -41,36 +41,60 @@
torch::PyWarningHandler __enforce_warning_buffer; \
try{
#define CATCH_TH_ERRORS(retval) \
#define CATCH_TH_ERRORS(retstmnt) \
catch (python_error & e) { \
return retval; \
retstmnt; \
} \
catch (const c10::IndexError& e) { \
auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
PyErr_SetString(PyExc_IndexError, msg.c_str()); \
return retval; \
retstmnt; \
} \
catch (const c10::Error& e) { \
auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
return retval; \
retstmnt; \
} \
catch (torch::PyTorchError & e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg.c_str()); \
return retval; \
retstmnt; \
} \
catch (const std::exception& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
return retval; \
retstmnt; \
}
#define END_HANDLE_TH_ERRORS_PYBIND \
} \
catch (py::error_already_set & e) { \
/* Unpack already stored error to be detectable by warning code */ \
e.restore(); \
throw; \
} \
catch (py::builtin_exception & e) { \
/* Unpack already stored error to be detectable by warning code */ \
e.set_error(); \
throw; \
} \
CATCH_TH_ERRORS(throw) \
} \
catch (py::error_already_set & e) { \
/* Repack already stored error */ \
throw py::error_already_set(); \
} \
catch (py::builtin_exception & e) { \
/* Repack already stored error */ \
throw py::error_already_set(); \
} \
CATCH_TH_ERRORS(throw py::error_already_set())
#define END_HANDLE_TH_ERRORS_RET(retval) \
} \
CATCH_TH_ERRORS(retval) \
CATCH_TH_ERRORS(return retval) \
} \
CATCH_TH_ERRORS(retval)
CATCH_TH_ERRORS(return retval)
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)

View File

@ -48,15 +48,22 @@ Similarly, if we raise a C++ exception, prior to returning to the Python
interpreter, we must set the Python error flags, so it turns into a C++
exception.
Exceptions defines some useful helpers: `HANDLE_TH_ERRORS`, `END_HANDLE_TH_ERRORS`
and an exception class `python_error`. You call them like this:
Moreover, when using the following macros, the generated warnings
will be converted into Python warnings that can be caught by the user.
Exceptions define helpers for two main cases:
* For code where you write the python binding by hand, `HANDLE_TH_ERRORS`,
`END_HANDLE_TH_ERRORS` and an exception class `python_error`. You call them like this:
```
// Entry point from Python interpreter
PyObject* run() {
PyObject* run(PyObject* arg) {
HANDLE_TH_ERRORS
...
if (!x) throw python_error();
// From c10/Exception.h
TORCH_CHECK(cond, "cond was false here");
TORCH_WARN("Warning message");
...
END_HANDLE_TH_ERRORS
}
@ -68,6 +75,27 @@ exception which doesn't contain any info, instead it says, "An error
occurred in the Python API; if you return to the interpreter, Python
will raise that exception, nothing else needs to be done."
* For code that you bind using pybind, `HANDLE_TH_ERRORS` and `END_HANDLE_TH_ERRORS_PYBIND`
can be used. They will work jointly with pybind error handling to raise
PyTorch errors and warnings natively, and let pybind handle other errors. It can be used as:
```
// Function given to the pybind binding
at::Tensor foo(at::Tensor x) {
HANDLE_TH_ERRORS
...
if (!x) throw python_error();
// pybind native error
if (!x) throw py::value_error();
// From c10/Exception.h
TORCH_CHECK(cond, "cond was false here");
TORCH_WARN("Warning message");
...
END_HANDLE_TH_ERRORS_PYBIND
}
```
### `utils/auto_gil.h`
Whenever you make any calls to the Python API, you must have taken out

View File

@ -102,6 +102,25 @@ COMMON_NVCC_FLAGS = [
'--expt-relaxed-constexpr'
]
# See comment in load_inline for more information
# The goal is to be able to call the safe version of the
# function exactly as if it were the original one.
# We need to create a pointer to this new function to give
# it to pybind later.
SAFE_FUNCTION_DEFINITION = '''
#include <functional>
template <typename Ret, typename ...Args>
auto _get_safe_version(Ret (*f)(Args...)) -> std::function<Ret(Args...)> {{
return [f](Args&& ...args) -> Ret {{
HANDLE_TH_ERRORS
return f(std::forward<Args>(args)...);
END_HANDLE_TH_ERRORS_PYBIND
}};
}}
'''
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@ -672,7 +691,8 @@ def load_inline(name,
build_directory=None,
verbose=False,
with_cuda=None,
is_python_module=True):
is_python_module=True,
with_pytorch_error_handling=True):
'''
Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
@ -717,8 +737,14 @@ def load_inline(name,
with_cuda: Determines whether CUDA headers and libraries are added to
the build. If set to ``None`` (default), this value is
automatically determined based on whether ``cuda_sources`` is
provided. Set it to `True`` to force CUDA headers
provided. Set it to ``True`` to force CUDA headers
and libraries to be included.
with_pytorch_error_handling: Determines whether pytorch error and
warning macros are handled by pytorch instead of pybind. To do
this, each function ``foo`` is called via an intermediary ``_safe_foo``
function. This redirection might cause issues in obscure cases
of C++. This flag should be set to ``False`` when this redirect
causes issues.
Example:
>>> from torch.utils.cpp_extension import load_inline
@ -741,11 +767,17 @@ def load_inline(name,
cpp_sources.insert(0, '#include <torch/extension.h>')
# Adds a new `_get_safe_version(foo)` function that returns a new function
# that performs the same operation as `foo` but with pytorch error handling
# macros.
cpp_sources.append(SAFE_FUNCTION_DEFINITION)
# If `functions` is supplied, we create the pybind11 bindings for the user.
# Here, `functions` is (or becomes, after some processing) a map from
# function names to function docstrings.
if functions is not None:
cpp_sources.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
module_def = []
module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
if isinstance(functions, str):
functions = [functions]
if isinstance(functions, list):
@ -756,9 +788,13 @@ def load_inline(name,
"Expected 'functions' to be a list or dict, but was {}".format(
type(functions)))
for function_name, docstring in functions.items():
cpp_sources.append('m.def("{0}", &{0}, "{1}");'.format(
function_name, docstring))
cpp_sources.append('}')
if with_pytorch_error_handling:
module_def.append('m.def("{0}", _get_safe_version({0}), "{1}");'.format(
function_name, docstring))
else:
module_def.append('m.def("{0}", {0}, "{1}");'.format(function_name, docstring))
module_def.append('}')
cpp_sources += module_def
cpp_source_path = os.path.join(build_directory, 'main.cpp')
with open(cpp_source_path, 'w') as cpp_source_file:

View File

@ -1,2 +1,3 @@
vptr:libtorch.so
vptr:libtorch_python.so
vptr:libcaffe2.so