mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT/Futures] support set_exception api (#50983)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50983 There is currently no way to handle/propagate errors with the python-based futures API (they are raised correctly if set with an error, but this is only possible from C++). This diff allows the Future's `unwrap_func` to be set in python optionally, so users can set futures completed with an exception and the error will throw as expected. This is mostly to support the following use case in the next diff: ``` ret_fut = torch.futures.Future(unwrap_func = lambda python_result: { # throw exception if needed if isinstance(python_result, Exception): throw python_result }) rpc_fut = rpc.rpc_async(...) # RPC future that times out # Goal is to propagate RPC error to this future rpc_fut.add_done_callback( res => { # Note that ret_fut.set_result(res.wait()) won't propagate the error try: ret_fut.set_result(res.wait()) except Exception as e: ret_fut.set_result(e) } ) ``` ghstack-source-id: 121021434 Test Plan: unittest ``` buck test mode/dev-nosan mode/no-gpu //caffe2/test:futures -- te st_unwrap --print-passing-details ``` Reviewed By: mrshenli Differential Revision: D25950304 fbshipit-source-id: 7ee61e98fcd783b3f515706fa141d538e6d2174d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8e78dd6de8
commit
c941730b96
@ -4,6 +4,9 @@ import torch
|
||||
import unittest
|
||||
from torch.futures import Future
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def add_one(fut):
|
||||
@ -11,6 +14,63 @@ def add_one(fut):
|
||||
|
||||
|
||||
class TestFuture(TestCase):
|
||||
def test_set_exception(self) -> None:
|
||||
# This test is to ensure errors can propagate across futures.
|
||||
error_msg = "Intentional Value Error"
|
||||
value_error = ValueError(error_msg)
|
||||
|
||||
f = Future[T]()
|
||||
# Set exception
|
||||
f.set_exception(value_error) # type: ignore
|
||||
# Exception should throw on wait
|
||||
with self.assertRaisesRegex(ValueError, "Intentional"):
|
||||
f.wait()
|
||||
|
||||
# Exception should also throw on value
|
||||
f = Future()
|
||||
f.set_exception(value_error) # type: ignore
|
||||
with self.assertRaisesRegex(ValueError, "Intentional"):
|
||||
f.value() # type: ignore
|
||||
|
||||
def cb(fut):
|
||||
fut.value() # type: ignore
|
||||
|
||||
f = Future()
|
||||
f.set_exception(value_error) # type: ignore
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
|
||||
cb_fut = f.then(cb)
|
||||
cb_fut.wait()
|
||||
|
||||
def test_set_exception_multithreading(self) -> None:
|
||||
# Ensure errors can propagate when one thread waits on future result
|
||||
# and the other sets it with an error.
|
||||
error_msg = "Intentional Value Error"
|
||||
value_error = ValueError(error_msg)
|
||||
|
||||
def wait_future(f):
|
||||
with self.assertRaisesRegex(ValueError, "Intentional"):
|
||||
f.wait()
|
||||
|
||||
f = Future[T]()
|
||||
t = threading.Thread(target=wait_future, args=(f, ))
|
||||
t.start()
|
||||
f.set_exception(value_error) # type: ignore
|
||||
t.join()
|
||||
|
||||
def cb(fut):
|
||||
fut.value() # type: ignore
|
||||
|
||||
def then_future(f):
|
||||
fut = f.then(cb)
|
||||
with self.assertRaisesRegex(RuntimeError, "Got the following error"):
|
||||
fut.wait()
|
||||
|
||||
f = Future[T]()
|
||||
t = threading.Thread(target=then_future, args=(f, ))
|
||||
t.start()
|
||||
f.set_exception(value_error) # type: ignore
|
||||
t.join()
|
||||
|
||||
def test_done(self) -> None:
|
||||
f = Future[torch.Tensor]()
|
||||
|
@ -150,6 +150,7 @@ class Future(object):
|
||||
def add_done_callback(self, callback: Callable) -> None: ...
|
||||
def then(self, callback: Callable) -> Future: ...
|
||||
def set_result(self, result: Any) -> None: ...
|
||||
def _set_unwrap_func(self, callback: Callable) -> None: ...
|
||||
|
||||
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
|
||||
|
||||
|
@ -1183,6 +1183,21 @@ void initJITBindings(PyObject* module) {
|
||||
"set_result",
|
||||
// Intentionally not releasing GIL
|
||||
&PythonFutureWrapper::markCompleted)
|
||||
.def(
|
||||
"_set_unwrap_func",
|
||||
// Intentionally not releasing GIL as this just does an assign
|
||||
[](PythonFutureWrapper& self, py::function unwrapFunc) {
|
||||
auto functionGuard =
|
||||
std::make_shared<torch::jit::PythonFunctionGuard>(
|
||||
std::move(unwrapFunc));
|
||||
|
||||
std::function<void(py::object)> pf =
|
||||
[functionGuard(std::move(functionGuard))](
|
||||
const py::object& inp) {
|
||||
return functionGuard->func_(inp);
|
||||
};
|
||||
self.unwrap_func = std::move(pf);
|
||||
})
|
||||
.def(
|
||||
py::pickle(
|
||||
/* __getstate__ */
|
||||
|
@ -63,6 +63,25 @@ IValue toIValue(
|
||||
|
||||
py::object toPyObject(IValue ivalue);
|
||||
|
||||
// Wrap Python function to guard deref
|
||||
// NB: Need VISIBILITY_HIDDEN for silencing compiler error,
|
||||
// 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
|
||||
// type of its field 'torch::jit::PythonFunctionGuard::func_'
|
||||
struct VISIBILITY_HIDDEN PythonFunctionGuard {
|
||||
explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}
|
||||
|
||||
~PythonFunctionGuard() {
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
func_.dec_ref();
|
||||
// explicitly setting PyObject* to nullptr to prevent py::object's dtor to
|
||||
// decref on the PyObject again.
|
||||
// See Note [Destructing py::object] in python_ivalue.h
|
||||
func_.ptr() = nullptr;
|
||||
}
|
||||
|
||||
py::function func_;
|
||||
};
|
||||
|
||||
// The PythonFutureWrapper for ivalue::Future
|
||||
//
|
||||
// NB: VISIBILITY_HIDDEN is for silencing compiling error,
|
||||
@ -93,6 +112,9 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
|
||||
// without grabbing the GIL.
|
||||
py::gil_scoped_acquire acquire;
|
||||
py::object py_obj = toPyObject(fut->value());
|
||||
// unwrap_func is a general compositional function that takes in a
|
||||
// py::object and executes some python function. It is currently mostly used
|
||||
// to throw python exceptions.
|
||||
if (unwrap_func) {
|
||||
(*unwrap_func)(py_obj);
|
||||
}
|
||||
@ -197,22 +219,6 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
|
||||
c10::optional<UnwrapFunc> unwrap_func;
|
||||
|
||||
private:
|
||||
// Wrap Python function to guard deref
|
||||
struct PythonFunctionGuard {
|
||||
explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}
|
||||
|
||||
~PythonFunctionGuard() {
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
func_.dec_ref();
|
||||
// explicitly setting PyObject* to nullptr to prevent py::object's dtor to
|
||||
// decref on the PyObject again.
|
||||
// See Note [Destructing py::object] in python_ivalue.h
|
||||
func_.ptr() = nullptr;
|
||||
}
|
||||
|
||||
py::function func_;
|
||||
};
|
||||
|
||||
std::shared_ptr<PythonFutureWrapper> getPtr() {
|
||||
return shared_from_this();
|
||||
}
|
||||
|
@ -150,6 +150,35 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
"""
|
||||
super().set_result(result)
|
||||
|
||||
def set_exception(self, result: T) -> None:
|
||||
r"""
|
||||
Set an exception for this ``Future``, which will mark this ``Future`` as
|
||||
completed with an error and trigger all attached callbacks. Note that
|
||||
when calling wait()/value() on this ``Future``, the exception set here
|
||||
will be raised inline.
|
||||
|
||||
Args:
|
||||
result (BaseException): the exception for this ``Future``.
|
||||
|
||||
Example::
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> fut = torch.futures.Future()
|
||||
>>> fut.set_exception(ValueError("foo"))
|
||||
>>> fut.wait()
|
||||
>>>
|
||||
>>> # Output:
|
||||
>>> # This will run after the future has finished.
|
||||
>>> ValueError: foo
|
||||
"""
|
||||
assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."
|
||||
|
||||
def raise_error(fut_result):
|
||||
raise fut_result
|
||||
|
||||
super()._set_unwrap_func(raise_error)
|
||||
self.set_result(result) # type: ignore
|
||||
|
||||
|
||||
def collect_all(futures: List[Future]) -> Future[List[Future]]:
|
||||
r"""
|
||||
|
Reference in New Issue
Block a user