[JIT/Futures] support set_exception api (#50983)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50983

There is currently no way to handle/propagate errors with the python-based futures API (they are raised correctly if set with an error, but this is only possible from C++).

This diff allows the Future's `unwrap_func` to be optionally set from Python, so users can complete a future with an exception and have that error raised as expected. This is mostly to support the following use case in the next diff:

```
ret_fut = torch.futures.Future(unwrap_func = lambda python_result: {
    # throw exception if needed
    if isinstance(python_result, Exception):
        raise python_result
})

rpc_fut = rpc.rpc_async(...) # RPC future that times out
# Goal is to propagate RPC error to this future
rpc_fut.add_done_callback(
res => {
    # Note that ret_fut.set_result(res.wait()) won't propagate the error
    try:
        ret_fut.set_result(res.wait())
    except Exception as e:
        ret_fut.set_result(e)
}
)
```
ghstack-source-id: 121021434

Test Plan:
unittest
```
buck test mode/dev-nosan mode/no-gpu //caffe2/test:futures -- test_unwrap --print-passing-details
```

Reviewed By: mrshenli

Differential Revision: D25950304

fbshipit-source-id: 7ee61e98fcd783b3f515706fa141d538e6d2174d
This commit is contained in:
Rohan Varma
2021-02-04 20:15:25 -08:00
committed by Facebook GitHub Bot
parent 8e78dd6de8
commit c941730b96
5 changed files with 127 additions and 16 deletions

View File

@ -4,6 +4,9 @@ import torch
import unittest import unittest
from torch.futures import Future from torch.futures import Future
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
from typing import TypeVar
T = TypeVar("T")
def add_one(fut): def add_one(fut):
@ -11,6 +14,63 @@ def add_one(fut):
class TestFuture(TestCase): class TestFuture(TestCase):
def test_set_exception(self) -> None:
    # Checks that an exception stored on a Future is raised by wait() and by
    # value(), and that a then() callback reading the value makes the
    # chained future fail with a RuntimeError.
    message = "Intentional Value Error"
    err = ValueError(message)

    # wait() on a future completed with an exception raises that exception.
    fut = Future[T]()
    fut.set_exception(err)  # type: ignore
    with self.assertRaisesRegex(ValueError, "Intentional"):
        fut.wait()

    # value() behaves the same way as wait().
    fut = Future()
    fut.set_exception(err)  # type: ignore
    with self.assertRaisesRegex(ValueError, "Intentional"):
        fut.value()  # type: ignore

    def read_value(done_fut):
        done_fut.value()  # type: ignore

    # Reading the value inside a then() callback fails the chained future.
    fut = Future()
    fut.set_exception(err)  # type: ignore
    with self.assertRaisesRegex(RuntimeError, "Got the following error"):
        chained = fut.then(read_value)
        chained.wait()
def test_set_exception_multithreading(self) -> None:
    # Ensure errors can propagate when one thread waits on future result
    # and the other sets it with an error.
    error_msg = "Intentional Value Error"
    value_error = ValueError(error_msg)

    # An exception escaping a worker thread (including a failed assertion)
    # does not fail the test by itself, so each worker records what it
    # raises and the main thread re-raises it after join().
    thread_failures = []  # type: ignore

    def run_checked(check, fut):
        # Start `check(fut)` on a new thread, capturing anything it raises.
        def target():
            try:
                check(fut)
            except Exception as e:
                thread_failures.append(e)

        worker = threading.Thread(target=target)
        worker.start()
        return worker

    def reraise_thread_failure():
        # Surface the first failure recorded by a worker, if any.
        if thread_failures:
            raise thread_failures[0]

    def wait_future(f):
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

    f = Future[T]()
    t = run_checked(wait_future, f)
    f.set_exception(value_error)  # type: ignore
    t.join()
    reraise_thread_failure()

    def cb(fut):
        fut.value()  # type: ignore

    def then_future(f):
        fut = f.then(cb)
        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            fut.wait()

    f = Future[T]()
    t = run_checked(then_future, f)
    f.set_exception(value_error)  # type: ignore
    t.join()
    reraise_thread_failure()
def test_done(self) -> None: def test_done(self) -> None:
f = Future[torch.Tensor]() f = Future[torch.Tensor]()

View File

@ -150,6 +150,7 @@ class Future(object):
def add_done_callback(self, callback: Callable) -> None: ... def add_done_callback(self, callback: Callable) -> None: ...
def then(self, callback: Callable) -> Future: ... def then(self, callback: Callable) -> Future: ...
def set_result(self, result: Any) -> None: ... def set_result(self, result: Any) -> None: ...
# Private: stores `callback` as this future's unwrap function.
def _set_unwrap_func(self, callback: Callable) -> None: ...
def _jit_set_num_profiled_runs(num: _size) -> _size: ... def _jit_set_num_profiled_runs(num: _size) -> _size: ...

View File

@ -1183,6 +1183,21 @@ void initJITBindings(PyObject* module) {
"set_result", "set_result",
// Intentionally not releasing GIL // Intentionally not releasing GIL
&PythonFutureWrapper::markCompleted) &PythonFutureWrapper::markCompleted)
.def(
    "_set_unwrap_func",
    // The GIL is intentionally kept held: this only stores a callable.
    [](PythonFutureWrapper& self, py::function unwrapFunc) {
      // Hold the Python callable behind a guard whose destructor acquires
      // the GIL before dropping the reference.
      auto guardedFunc = std::make_shared<torch::jit::PythonFunctionGuard>(
          std::move(unwrapFunc));
      self.unwrap_func = std::function<void(py::object)>(
          [guardedFunc = std::move(guardedFunc)](const py::object& value) {
            // The callable's return value is discarded.
            guardedFunc->func_(value);
          });
    })
.def( .def(
py::pickle( py::pickle(
/* __getstate__ */ /* __getstate__ */

View File

@ -63,6 +63,25 @@ IValue toIValue(
py::object toPyObject(IValue ivalue); py::object toPyObject(IValue ivalue);
// Holder for a Python callable that drops its reference in the destructor
// only after acquiring the GIL.
// NB: VISIBILITY_HIDDEN is needed to silence the compiler error
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
// type of its field 'torch::jit::PythonFunctionGuard::func_'
struct VISIBILITY_HIDDEN PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    py::gil_scoped_acquire gil;
    func_.dec_ref();
    // Null out the held PyObject* so that py::object's own destructor does
    // not decref it a second time.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.ptr() = nullptr;
  }

  // The guarded Python callable.
  py::function func_;
};
// The PythonFutureWrapper for ivalue::Future // The PythonFutureWrapper for ivalue::Future
// //
// NB: VISIBILITY_HIDDEN is for silencing compiling error, // NB: VISIBILITY_HIDDEN is for silencing compiling error,
@ -93,6 +112,9 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
// without grabbing the GIL. // without grabbing the GIL.
py::gil_scoped_acquire acquire; py::gil_scoped_acquire acquire;
py::object py_obj = toPyObject(fut->value()); py::object py_obj = toPyObject(fut->value());
// unwrap_func is a general compositional function that takes in a
// py::object and executes some python function. It is currently mostly used
// to throw python exceptions.
if (unwrap_func) { if (unwrap_func) {
(*unwrap_func)(py_obj); (*unwrap_func)(py_obj);
} }
@ -197,22 +219,6 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
c10::optional<UnwrapFunc> unwrap_func; c10::optional<UnwrapFunc> unwrap_func;
private: private:
// Holder for a Python callable that drops its reference in the destructor
// only after acquiring the GIL.
struct PythonFunctionGuard {
  explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}

  ~PythonFunctionGuard() {
    py::gil_scoped_acquire gil;
    func_.dec_ref();
    // Null out the held PyObject* so that py::object's own destructor does
    // not decref it a second time.
    // See Note [Destructing py::object] in python_ivalue.h
    func_.ptr() = nullptr;
  }

  // The guarded Python callable.
  py::function func_;
};
std::shared_ptr<PythonFutureWrapper> getPtr() { std::shared_ptr<PythonFutureWrapper> getPtr() {
return shared_from_this(); return shared_from_this();
} }

View File

@ -150,6 +150,35 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
""" """
super().set_result(result) super().set_result(result)
def set_exception(self, result: T) -> None:
    r"""
    Complete this ``Future`` with an exception. This marks the ``Future`` as
    completed with an error and triggers all attached callbacks. A later
    call to wait()/value() on this ``Future`` raises the exception that was
    passed in here.

    Args:
        result (Exception): the exception for this ``Future``.

    Example::
        >>> import torch
        >>>
        >>> fut = torch.futures.Future()
        >>> fut.set_exception(ValueError("foo"))
        >>> fut.wait()
        >>>
        >>> # Output:
        >>> # ValueError: foo
    """
    assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

    def _reraise(stored):
        # Installed as the unwrap function: raises the stored exception.
        raise stored

    super()._set_unwrap_func(_reraise)
    # The exception object itself is stored as the future's result.
    self.set_result(result)  # type: ignore
def collect_all(futures: List[Future]) -> Future[List[Future]]: def collect_all(futures: List[Future]) -> Future[List[Future]]:
r""" r"""