mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
add pybind version of HANDLE_TH_ERRORS
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26614 Test Plan: Imported from OSS Differential Revision: D18249634 Pulled By: albanD fbshipit-source-id: 25503f368926e0f3633c5af0f222c9bb4729f342
This commit is contained in:
committed by
Facebook Github Bot
parent
9b875e1256
commit
0ff1696c75
@ -727,6 +727,111 @@ class TestCppExtension(common.TestCase):
|
||||
pattern = r'.*(\\n|\\r).*'
|
||||
self.assertNotRegex(str(e), pattern)
|
||||
|
||||
def test_warning(self):
|
||||
# Note: the module created from this source will include the py::key_error
|
||||
# symbol. But because of visibility and the fact that it lives in a
|
||||
# different compilation unit than pybind, this trips up ubsan even though
|
||||
# it is fine. "ubsan.supp" thus needs to contain "vptr:warn_mod.so".
|
||||
source = '''
|
||||
// error_type:
|
||||
// 0: no error
|
||||
// 1: torch::TypeError
|
||||
// 2: python_error()
|
||||
// 3: py::error_already_set
|
||||
at::Tensor foo(at::Tensor x, int error_type) {
|
||||
std::ostringstream err_stream;
|
||||
err_stream << "Error with " << x.type();
|
||||
|
||||
TORCH_WARN(err_stream.str());
|
||||
if(error_type == 1) {
|
||||
throw torch::TypeError(err_stream.str().c_str());
|
||||
}
|
||||
if(error_type == 2) {
|
||||
PyObject* obj = PyTuple_New(-1);
|
||||
TORCH_CHECK(!obj);
|
||||
// Pretend it was caught in a different thread and restored here
|
||||
auto e = python_error();
|
||||
e.persist();
|
||||
e.restore();
|
||||
throw e;
|
||||
}
|
||||
if(error_type == 3) {
|
||||
throw py::key_error(err_stream.str());
|
||||
}
|
||||
return x.cos();
|
||||
}
|
||||
'''
|
||||
|
||||
# Ensure double type for hard-coded c name below
|
||||
t = torch.rand(2).double()
|
||||
cpp_tensor_name = r"CPUDoubleType"
|
||||
|
||||
# Without error handling, the warnings cannot be catched
|
||||
# and the Tensor type names are not cleaned
|
||||
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
|
||||
cpp_sources=[source],
|
||||
functions=['foo'],
|
||||
with_pytorch_error_handling=False)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warn_mod.foo(t, 0)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
# pybind translate all our errors to RuntimeError
|
||||
with self.assertRaisesRegex(RuntimeError, cpp_tensor_name):
|
||||
warn_mod.foo(t, 1)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "bad argument to internal function|python_error"):
|
||||
warn_mod.foo(t, 2)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
with self.assertRaisesRegex(KeyError, cpp_tensor_name):
|
||||
warn_mod.foo(t, 3)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
|
||||
warn_mod = torch.utils.cpp_extension.load_inline(name='warn_mod',
|
||||
cpp_sources=[source],
|
||||
functions=['foo'],
|
||||
with_pytorch_error_handling=True)
|
||||
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
# Catched with no error should be detected
|
||||
warn_mod.foo(t, 0)
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
# Catched with cpp error should not be detected
|
||||
with self.assertRaisesRegex(TypeError, t.type()):
|
||||
warn_mod.foo(t, 1)
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
# Catched with python error should not be detected
|
||||
with self.assertRaisesRegex(SystemError, "bad argument to internal function"):
|
||||
warn_mod.foo(t, 2)
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
# Catched with pybind error should not be detected
|
||||
# Note that there is no type name translation for pybind errors
|
||||
with self.assertRaisesRegex(KeyError, cpp_tensor_name):
|
||||
warn_mod.foo(t, 3)
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
# Make sure raising warnings are handled properly
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("error")
|
||||
|
||||
# No error, the warning should raise
|
||||
with self.assertRaisesRegex(UserWarning, t.type()):
|
||||
warn_mod.foo(t, 0)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
# Another error happened, the warning is ignored
|
||||
with self.assertRaisesRegex(TypeError, t.type()):
|
||||
warn_mod.foo(t, 1)
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
|
||||
class TestMSNPUTensor(common.TestCase):
|
||||
@classmethod
|
||||
|
@ -5680,7 +5680,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
self.assertEqual(e1, e2)
|
||||
|
||||
def test_batch_norm_cpu_inference(self):
|
||||
# input nchw in (2,1,1,1), (2,2,2,2)
|
||||
# input nchw in (2,1,1,1), (2,2,2,2)
|
||||
inputs = [
|
||||
torch.tensor([[[[-0.5000]]], [[[0.5000]]]]),
|
||||
torch.tensor([
|
||||
@ -5692,7 +5692,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
[[0.1000, 1.0000], [1.0000, 0.1000]],
|
||||
[[1.0000, 0.5000], [1.5000, -1.5000]]
|
||||
]])]
|
||||
# output nchw in (2,1,1,1), (2,2,2,2)
|
||||
# output nchw in (2,1,1,1), (2,2,2,2)
|
||||
outputs = [
|
||||
torch.tensor([
|
||||
[[[-0.499997496604919433593750000]]],
|
||||
@ -5718,9 +5718,9 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
output2 = m(input2).permute(0, 1, 3, 2)
|
||||
# channels last case
|
||||
input3 = input1.contiguous(memory_format=torch.channels_last)
|
||||
for name, param in m.named_parameters():
|
||||
if param.requires_grad:
|
||||
if param.data.dim() == 4:
|
||||
for name, param in m.named_parameters():
|
||||
if param.requires_grad:
|
||||
if param.data.dim() == 4:
|
||||
param.data = param.data.contiguous(memory_format=torch.channels_last)
|
||||
output3 = m(input3)
|
||||
self.assertEqual(output3, outputs[i])
|
||||
@ -14292,7 +14292,6 @@ def generate_not_implemented_tests(cls):
|
||||
class TestTensorDeviceOps(TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class TestTorch(TestCase, _TestTorchMixin):
|
||||
pass
|
||||
|
||||
|
@ -41,36 +41,60 @@
|
||||
torch::PyWarningHandler __enforce_warning_buffer; \
|
||||
try{
|
||||
|
||||
#define CATCH_TH_ERRORS(retval) \
|
||||
#define CATCH_TH_ERRORS(retstmnt) \
|
||||
catch (python_error & e) { \
|
||||
return retval; \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (const c10::IndexError& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
|
||||
PyErr_SetString(PyExc_IndexError, msg.c_str()); \
|
||||
return retval; \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (const c10::Error& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
|
||||
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
|
||||
return retval; \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (torch::PyTorchError & e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(e.python_type(), msg.c_str()); \
|
||||
return retval; \
|
||||
retstmnt; \
|
||||
} \
|
||||
catch (const std::exception& e) { \
|
||||
auto msg = torch::processErrorMsg(e.what()); \
|
||||
PyErr_SetString(PyExc_RuntimeError, msg.c_str()); \
|
||||
return retval; \
|
||||
retstmnt; \
|
||||
}
|
||||
|
||||
#define END_HANDLE_TH_ERRORS_PYBIND \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
/* Unpack already stored error to be detectable by warning code */ \
|
||||
e.restore(); \
|
||||
throw; \
|
||||
} \
|
||||
catch (py::builtin_exception & e) { \
|
||||
/* Unpack already stored error to be detectable by warning code */ \
|
||||
e.set_error(); \
|
||||
throw; \
|
||||
} \
|
||||
CATCH_TH_ERRORS(throw) \
|
||||
} \
|
||||
catch (py::error_already_set & e) { \
|
||||
/* Repack already stored error */ \
|
||||
throw py::error_already_set(); \
|
||||
} \
|
||||
catch (py::builtin_exception & e) { \
|
||||
/* Repack already stored error */ \
|
||||
throw py::error_already_set(); \
|
||||
} \
|
||||
CATCH_TH_ERRORS(throw py::error_already_set())
|
||||
|
||||
#define END_HANDLE_TH_ERRORS_RET(retval) \
|
||||
} \
|
||||
CATCH_TH_ERRORS(retval) \
|
||||
CATCH_TH_ERRORS(return retval) \
|
||||
} \
|
||||
CATCH_TH_ERRORS(retval)
|
||||
CATCH_TH_ERRORS(return retval)
|
||||
|
||||
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
|
||||
|
||||
|
@ -48,15 +48,22 @@ Similarly, if we raise a C++ exception, prior to returning to the Python
|
||||
interpreter, we must set the Python error flags, so it turns into a C++
|
||||
exception.
|
||||
|
||||
Exceptions defines some useful helpers: `HANDLE_TH_ERRORS`, `END_HANDLE_TH_ERRORS`
|
||||
and an exception class `python_error`. You call them like this:
|
||||
Moreover, when using the following macros, the generated warnings
|
||||
will be converted into python warnings that can be caught by the user.
|
||||
|
||||
Exceptions define helpers for two main cases:
|
||||
* For code where you write the python binding by hand, `HANDLE_TH_ERRORS`,
|
||||
`END_HANDLE_TH_ERRORS` and an exception class `python_error`. You call them like this:
|
||||
|
||||
```
|
||||
// Entry point from Python interpreter
|
||||
PyObject* run() {
|
||||
PyObject* run(PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
...
|
||||
if (!x) throw python_error();
|
||||
// From c10/Exception.h
|
||||
TORCH_CHECK(cond, "cond was false here");
|
||||
TORCH_WARN("Warning message");
|
||||
...
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -68,6 +75,27 @@ exception which doesn't contain any info, instead it says, "An error
|
||||
occurred in the Python API; if you return to the interpreter, Python
|
||||
will raise that exception, nothing else needs to be done."
|
||||
|
||||
* For code that you bind using pybind, `HANDLE_TH_ERRORS` and `END_HANDLE_TH_ERRORS_PYBIND`
|
||||
can be used. They will work jointly with pybind error handling to raise
|
||||
pytorch errors and warnings natively and let pybind handle other errors. It can be used as:
|
||||
|
||||
```
|
||||
// Function given to the pybind binding
|
||||
at::Tensor foo(at::Tensor x) {
|
||||
HANDLE_TH_ERRORS
|
||||
...
|
||||
if (!x) throw python_error();
|
||||
// pybind native error
|
||||
if (!x) throw py::value_error();
|
||||
// From c10/Exception.h
|
||||
TORCH_CHECK(cond, "cond was false here");
|
||||
TORCH_WARN("Warning message");
|
||||
...
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### `utils/auto_gil.h`
|
||||
|
||||
Whenever you make any calls to the Python API, you must have taken out
|
||||
|
@ -102,6 +102,25 @@ COMMON_NVCC_FLAGS = [
|
||||
'--expt-relaxed-constexpr'
|
||||
]
|
||||
|
||||
# See comment in load_inline for more information
|
||||
# The goal is to be able to call the safe version of the
|
||||
# function exactely as if it was the original one.
|
||||
# We need to create a pointer to this new function to give
|
||||
# it to pybind later.
|
||||
|
||||
SAFE_FUNCTION_DEFINITION = '''
|
||||
#include <functional>
|
||||
|
||||
template <typename Ret, typename ...Args>
|
||||
auto _get_safe_version(Ret (*f)(Args...)) -> std::function<Ret(Args...)> {{
|
||||
return [f](Args&& ...args) -> Ret {{
|
||||
HANDLE_TH_ERRORS
|
||||
return f(std::forward<Args>(args)...);
|
||||
END_HANDLE_TH_ERRORS_PYBIND
|
||||
}};
|
||||
}}
|
||||
|
||||
'''
|
||||
|
||||
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
||||
|
||||
@ -672,7 +691,8 @@ def load_inline(name,
|
||||
build_directory=None,
|
||||
verbose=False,
|
||||
with_cuda=None,
|
||||
is_python_module=True):
|
||||
is_python_module=True,
|
||||
with_pytorch_error_handling=True):
|
||||
'''
|
||||
Loads a PyTorch C++ extension just-in-time (JIT) from string sources.
|
||||
|
||||
@ -717,8 +737,14 @@ def load_inline(name,
|
||||
with_cuda: Determines whether CUDA headers and libraries are added to
|
||||
the build. If set to ``None`` (default), this value is
|
||||
automatically determined based on whether ``cuda_sources`` is
|
||||
provided. Set it to `True`` to force CUDA headers
|
||||
provided. Set it to ``True`` to force CUDA headers
|
||||
and libraries to be included.
|
||||
with_pytorch_error_handling: Determines whether pytorch error and
|
||||
warning macros are handled by pytorch instead of pybind. To do
|
||||
this, each function ``foo`` is called via an intermediary ``_safe_foo``
|
||||
function. This redirection might cause issues in obscure cases
|
||||
of cpp. This flag should be set to ``False`` when this redirect
|
||||
causes issues.
|
||||
|
||||
Example:
|
||||
>>> from torch.utils.cpp_extension import load_inline
|
||||
@ -741,11 +767,17 @@ def load_inline(name,
|
||||
|
||||
cpp_sources.insert(0, '#include <torch/extension.h>')
|
||||
|
||||
# Adds a new `_get_safe_version(foo)` function that returns a new function
|
||||
# that performs the same operation as `foo` but with pytorch error handling
|
||||
# macros.
|
||||
cpp_sources.append(SAFE_FUNCTION_DEFINITION)
|
||||
|
||||
# If `functions` is supplied, we create the pybind11 bindings for the user.
|
||||
# Here, `functions` is (or becomes, after some processing) a map from
|
||||
# function names to function docstrings.
|
||||
if functions is not None:
|
||||
cpp_sources.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
|
||||
module_def = []
|
||||
module_def.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
|
||||
if isinstance(functions, str):
|
||||
functions = [functions]
|
||||
if isinstance(functions, list):
|
||||
@ -756,9 +788,13 @@ def load_inline(name,
|
||||
"Expected 'functions' to be a list or dict, but was {}".format(
|
||||
type(functions)))
|
||||
for function_name, docstring in functions.items():
|
||||
cpp_sources.append('m.def("{0}", &{0}, "{1}");'.format(
|
||||
function_name, docstring))
|
||||
cpp_sources.append('}')
|
||||
if with_pytorch_error_handling:
|
||||
module_def.append('m.def("{0}", _get_safe_version({0}), "{1}");'.format(
|
||||
function_name, docstring))
|
||||
else:
|
||||
module_def.append('m.def("{0}", {0}, "{1}");'.format(function_name, docstring))
|
||||
module_def.append('}')
|
||||
cpp_sources += module_def
|
||||
|
||||
cpp_source_path = os.path.join(build_directory, 'main.cpp')
|
||||
with open(cpp_source_path, 'w') as cpp_source_file:
|
||||
|
@ -1,2 +1,3 @@
|
||||
vptr:libtorch.so
|
||||
vptr:libtorch_python.so
|
||||
vptr:libcaffe2.so
|
||||
|
Reference in New Issue
Block a user