Expose torch.futures.Future (#39008)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39008

This commit adds a `torch.futures.Future` type and exposes its constructor
and its `wait`, `then`, and `set_result` APIs. This type is currently a
wrapper around `c10::ivalue::Future` and is mainly used by RPC for now. Later,
we could revamp the c10d APIs to return this `Future` type as well. More
utilities will be added to the `torch.futures` package in follow-up PRs.

Test Plan: Imported from OSS

Differential Revision: D21723022

Pulled By: mrshenli

fbshipit-source-id: 92e56160544e9bf00d11db3e8347a1b9707882c9
This commit is contained in:
Shen Li
2020-06-02 10:09:48 -07:00
committed by Facebook GitHub Bot
parent b3fac8af6b
commit bb0377bb24
10 changed files with 328 additions and 16 deletions

View File

@ -259,7 +259,10 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
*/ */
void markCompleted(IValue value) { void markCompleted(IValue value) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
AT_ASSERT(!completed()); TORCH_CHECK(
!completed(),
"Attempting to mark a completed Future as complete again. Note that "
"a Future can only be marked completed once.");
completed_ = true; completed_ = true;
value_ = std::move(value); value_ = std::move(value);

View File

@ -49,6 +49,9 @@ ignore_errors = True
[mypy-torch.functional.*] [mypy-torch.functional.*]
ignore_errors = True ignore_errors = True
[mypy-torch.futures.*]
ignore_errors = True
[mypy-torch.testing._internal.*] [mypy-torch.testing._internal.*]
ignore_errors = True ignore_errors = True

View File

@ -79,6 +79,7 @@ TESTS = [
'test_determination', 'test_determination',
'distributed/rpc/jit/test_rpc_spawn', 'distributed/rpc/jit/test_rpc_spawn',
'distributed/rpc/faulty_agent/test_rpc_spawn', 'distributed/rpc/faulty_agent/test_rpc_spawn',
'test_futures',
] ]
WINDOWS_BLACKLIST = [ WINDOWS_BLACKLIST = [
@ -154,6 +155,7 @@ SLOW_TESTS = [
'distributed/test_c10d_spawn', 'distributed/test_c10d_spawn',
'test_quantization', 'test_quantization',
'test_determination', 'test_determination',
'test_futures',
] ]
_DEP_MODULES_CACHE = {} _DEP_MODULES_CACHE = {}

98
test/test_futures.py Normal file
View File

@ -0,0 +1,98 @@
import threading
import time
import torch
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase, TemporaryFileName
def add_one(fut):
    """Callback for ``Future.then``: return the value held by ``fut`` plus one."""
    value = fut.wait()
    return value + 1
class TestFuture(TestCase):
    """Unit tests for the Python-facing ``torch.futures.Future`` API."""

    def test_wait(self):
        """``wait`` returns the value previously passed to ``set_result``."""
        f = Future()
        f.set_result(torch.ones(2, 2))
        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self):
        """``wait`` blocks until a different thread sets the result."""
        def slow_set_future(fut, value):
            # Delay so that the main thread is already inside wait().
            time.sleep(0.5)
            fut.set_result(value)

        f = Future()
        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()
        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self):
        """A second ``set_result`` on the same future raises ``RuntimeError``."""
        fut = Future()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        """Serializing a future with ``torch.save`` raises ``RuntimeError``."""
        fut = Future()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        """``then`` returns a new future holding the callback's return value."""
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        """Each link of a 20-deep ``then`` chain adds one to the previous value."""
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)
        fut.set_result(torch.ones(2, 2))
        # The future at 1-based position ``depth`` has had add_one applied
        # ``depth`` times to the root value.
        for depth, chained_fut in enumerate(futs, start=1):
            self.assertEqual(chained_fut.wait(), torch.ones(2, 2) + depth)

    def _test_error(self, cb, errMsg):
        """Attach ``cb`` via ``then`` and check that only the child future fails.

        The parent future must still yield its value, while waiting on the
        future returned by ``then`` must raise a ``RuntimeError`` whose message
        matches the regex ``errMsg``.
        """
        fut = Future()
        then_fut = fut.then(cb)
        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        """A callback that treats its argument as a value, not a future, fails."""
        def wrong_arg(tensor):
            return tensor + 1

        self._test_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        """A callback that accepts no arguments fails when invoked."""
        def no_arg():
            return True

        self._test_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        """An exception raised inside the callback surfaces from ``wait``."""
        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_error(raise_value_error, "Expected error")

View File

@ -412,6 +412,7 @@ del ComplexFloatStorageBase
import torch.cuda import torch.cuda
import torch.autograd import torch.autograd
from torch.autograd import no_grad, enable_grad, set_grad_enabled from torch.autograd import no_grad, enable_grad, set_grad_enabled
import torch.futures
import torch.nn import torch.nn
import torch.nn.intrinsic import torch.nn.intrinsic
import torch.nn.quantized import torch.nn.quantized

View File

@ -833,13 +833,40 @@ void initJITBindings(PyObject* module) {
py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>( py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
m, "Future") m, "Future")
.def(py::init([]() {
return std::make_shared<PythonFutureWrapper>(
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
}))
.def( .def(
"wait", "wait",
&PythonFutureWrapper::wait, &PythonFutureWrapper::wait,
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def( .def(
"_then", "then",
&PythonFutureWrapper::then, &PythonFutureWrapper::then,
py::call_guard<py::gil_scoped_release>())
.def(
"set_result",
// Intentionally not releasing GIL
&PythonFutureWrapper::markCompleted)
.def(
py::pickle(
/* __getstate__ */
[](const PythonFutureWrapper& /* unused */) {
TORCH_CHECK(false, "Can not pickle torch.futures.Future");
// Note that this return has no meaning since we always
// throw, it's only here to satisfy Pybind API's
// requirement.
return py::make_tuple();
},
/* __setstate__ */
[](const py::tuple& /* unused */) { // NOLINT
TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
// Note that this return has no meaning since we always
// throw, it's only here to satisfy PyBind's API
// requirement.
return nullptr;
}),
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) { m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {

View File

@ -71,7 +71,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
c10::optional<UnwrapFunc> unwrap_func = c10::nullopt) c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
: fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {} : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}
PythonFutureWrapper(const PythonFutureWrapper&) = delete; explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete; PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
py::object wait() { py::object wait() {
@ -136,6 +136,14 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
PyObjectType::get())); PyObjectType::get()));
} }
// Completes the wrapped future with a value supplied from Python.
// The caller is expected to hold the GIL on entry; the GIL is released
// before the underlying c10 future is marked completed.
void markCompleted(const py::object& pyValue) {
// Debug-only check that the calling thread currently holds the GIL.
DCHECK(PyGILState_Check());
// Convert the Python object to an IValue while the GIL is still held.
IValue value = toIValue(pyValue, PyObjectType::get());
// Drop the GIL for the remainder of this scope.
// NOTE(review): presumably so that callbacks or waiters triggered by
// completing the future can acquire the GIL themselves -- confirm.
py::gil_scoped_release release;
fut->markCompleted(std::move(value));
}
c10::intrusive_ptr<c10::ivalue::Future> fut; c10::intrusive_ptr<c10::ivalue::Future> fut;
// unwrap_func works like a callback for the value returned by // unwrap_func works like a callback for the value returned by
// PythonFutureWrapper::wait(). // PythonFutureWrapper::wait().

96
torch/futures/__init__.py Normal file
View File

@ -0,0 +1,96 @@
"""
The ``torch.futures`` package contains a ``Future`` type and corresponding
utility functions.
"""
import torch
class Future(torch._C.Future):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning::
        The ``torch.futures.Future`` is experimental and subject to change.
    """

    def __new__(cls):
        return super().__new__(cls)

    def wait(self):
        r"""
        Block until the value of this ``Future`` is ready.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        return super().wait()

    def then(self, callback):
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, and will be invoked in the same order as they were
        added. The callback must take one argument, which is the reference to
        this ``Future``. The callback function can use the ``Future.wait()`` API
        to get the value.

        Arguments:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                                    the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        Example::

            >>> import torch
            >>>
            >>> def callback(fut):
            >>>     print(f"Future value is {fut.wait()}.")
            >>>
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the value once
            >>> # ``set_result`` is called on ``fut``.
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(lambda x : print(f"Chained cb done. {x.wait()}"))
            >>> fut.set_result(5)
            >>>
            >>> # Outputs are:
            >>> # Future value is 5.
            >>> # Chained cb done. None
        """
        return super().then(callback)

    def set_result(self, result):
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        Arguments:
            result (object): the result object of this ``Future``.

        Example::

            >>> import threading
            >>> import time
            >>> import torch
            >>>
            >>> def slow_set_future(fut, value):
            >>>     time.sleep(0.5)
            >>>     fut.set_result(value)
            >>>
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            >>>     target=slow_set_future,
            >>>     args=(fut, torch.ones(2) * 3)
            >>> )
            >>> t.start()
            >>>
            >>> print(fut.wait())  # tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

View File

@ -944,7 +944,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
worker_name((self.rank + 1) % self.world_size), worker_name((self.rank + 1) % self.world_size),
script_fork_wait_udf, script_fork_wait_udf,
args=(torch.ones(2),) args=(torch.ones(2),)
)._then(callback) ).then(callback)
self.assertEqual(future.wait(), torch.ones(2) * 2 + 1) self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
@dist_init @dist_init
@ -963,7 +963,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
num_cbs = 20 num_cbs = 20
for _ in range(num_cbs): for _ in range(num_cbs):
fut = fut._then(callback) fut = fut.then(callback)
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
@ -988,7 +988,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
worker_name((self.rank + 1) % self.world_size), worker_name((self.rank + 1) % self.world_size),
script_fork_wait_throw, script_fork_wait_throw,
args=(torch.ones(2),) args=(torch.ones(2),)
)._then(callback) ).then(callback)
with self.assertRaisesRegex(RuntimeError, "Another expected error"): with self.assertRaisesRegex(RuntimeError, "Another expected error"):
future.wait() future.wait()

View File

@ -312,7 +312,7 @@ def add_use_future_cb(to, x, y, z):
out.set_result(fut.wait() + z) out.set_result(fut.wait() + z)
fut = rpc.rpc_async(to, torch.add, args=(x, y)) fut = rpc.rpc_async(to, torch.add, args=(x, y))
fut._then(callback) fut.then(callback)
return out.result() return out.result()
@ -320,6 +320,29 @@ def get_events_from_profile(profile_rref):
return profile_rref.local_value().process_global_function_events return profile_rref.local_value().process_global_function_events
def add_use_future_set_result(to, x, y, z):
    """Compute ``x + y`` on worker ``to`` via RPC, add ``z`` in a callback.

    The callback publishes the sum through a separate ``torch.futures.Future``
    using ``set_result``; this function blocks on that future and returns
    its value.
    """
    out = torch.futures.Future()

    def _publish(done_fut):
        return out.set_result(done_fut.wait() + z)

    rpc_fut = rpc.rpc_async(to, torch.add, args=(x, y))
    rpc_fut.then(_publish)
    return out.wait()
def add_use_future_nested_cb(to, x, y, z):
    """Compute ``(x + y) + z`` with two RPCs to ``to``, the second nested in a callback.

    The first RPC's callback issues the second RPC, whose own callback
    publishes the final value through a ``torch.futures.Future``. This
    function blocks on that future and returns its value.
    """
    out = torch.futures.Future()

    def _publish(inner_fut):
        return out.set_result(inner_fut.wait())

    def _issue_second_add(first_fut):
        # Runs once the first add has completed; chains the second add.
        second_fut = rpc.rpc_async(to, torch.add, args=(first_fut.wait(), z))
        second_fut.then(_publish)

    rpc.rpc_async(to, torch.add, args=(x, y)).then(_issue_second_add)
    return out.wait()
def fail_on_fut(fut):
    """No-op RPC target; it ignores ``fut`` and returns ``None``."""
    return None
# load_tests from common_utils is used to automatically filter tests for # load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings # sharding on sandcastle. This line silences flake warnings
load_tests = load_tests load_tests = load_tests
@ -2481,7 +2504,7 @@ class RpcTest(RpcAgentTestFixture):
args=(torch.ones(n, n), torch.ones(n, n)) args=(torch.ones(n, n), torch.ones(n, n))
) )
fut._then(callback) fut.then(callback)
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1) self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
@ -2498,7 +2521,7 @@ class RpcTest(RpcAgentTestFixture):
args=(torch.ones(n, n), torch.ones(n, n)) args=(torch.ones(n, n), torch.ones(n, n))
) )
cb_fut = fut._then(my_function) cb_fut = fut.then(my_function)
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
@ -2513,7 +2536,7 @@ class RpcTest(RpcAgentTestFixture):
dst = worker_name((self.rank + 1) % self.world_size) dst = worker_name((self.rank + 1) % self.world_size)
fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1)) fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
fut1 = fut0._then(lambda x: x + 1) fut1 = fut0.then(lambda x: x + 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
@ -2539,7 +2562,7 @@ class RpcTest(RpcAgentTestFixture):
cb_futs = [] cb_futs = []
for idx in range(num_cbs): for idx in range(num_cbs):
cb_futs.append(fut._then(partial(callback, idx))) cb_futs.append(fut.then(partial(callback, idx)))
self.assertEqual(fut.wait(), torch.ones(n, n) * 2) self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
@ -2567,7 +2590,7 @@ class RpcTest(RpcAgentTestFixture):
num_cbs = 20 num_cbs = 20
for _ in range(num_cbs): for _ in range(num_cbs):
fut = fut._then(callback) fut = fut.then(callback)
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs) self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
@ -2592,7 +2615,7 @@ class RpcTest(RpcAgentTestFixture):
dst, dst,
torch.add, torch.add,
args=(fut0.wait(), 1) args=(fut0.wait(), 1)
)._then(lambda fut1: fut1.wait() + 1) ).then(lambda fut1: fut1.wait() + 1)
return fut2.wait() return fut2.wait()
@ -2600,7 +2623,7 @@ class RpcTest(RpcAgentTestFixture):
dst, dst,
torch.add, torch.add,
args=(torch.ones(2, 2), 1) args=(torch.ones(2, 2), 1)
)._then(callback) ).then(callback)
self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3) self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
@ -2613,7 +2636,7 @@ class RpcTest(RpcAgentTestFixture):
fut0.wait() fut0.wait()
raise RuntimeError("Another expected error") raise RuntimeError("Another expected error")
fut1 = rpc.rpc_async(dst, raise_func)._then(callback) fut1 = rpc.rpc_async(dst, raise_func).then(callback)
with self.assertRaisesRegex(RuntimeError, "Another expected error"): with self.assertRaisesRegex(RuntimeError, "Another expected error"):
fut1.wait() fut1.wait()
@ -2625,7 +2648,58 @@ class RpcTest(RpcAgentTestFixture):
TypeError, TypeError,
"incompatible function arguments." "incompatible function arguments."
): ):
rpc.rpc_async(dst, raise_func)._then(None) rpc.rpc_async(dst, raise_func).then(None)
@dist_init
def test_mark_future_twice(self):
    """Calling ``set_result`` on an already-completed RPC future raises."""
    dst = worker_name((self.rank + 1) % self.world_size)
    fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2, 2), 1))
    # Waiting first guarantees the future has been completed by the RPC.
    self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
    expected_msg = "Future can only be marked completed once"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        fut.set_result(1)
@dist_init
def test_pickle_future(self):
    """Passing a ``Future`` as an RPC argument raises a pickling error.

    Checked for ``rpc_sync``, ``rpc_async`` and ``remote``. Nothing is written
    to disk here, so no temporary file is created.
    """
    fut = torch.futures.Future()
    errMsg = "Can not pickle torch.futures.Future"
    dst = worker_name((self.rank + 1) % self.world_size)

    with self.assertRaisesRegex(RuntimeError, errMsg):
        rpc.rpc_sync(dst, fail_on_fut, args=(fut,))

    with self.assertRaisesRegex(RuntimeError, errMsg):
        rpc.rpc_async(dst, fail_on_fut, args=(fut,))

    with self.assertRaisesRegex(RuntimeError, errMsg):
        rpc.remote(dst, fail_on_fut, args=(fut,))
def _test_future_cb(self, func):
    """Run ``func`` on the next worker and check it returns ``ones(2, 2) + 1 + 2``.

    ``func`` is called as ``func(dst2, torch.ones(2, 2), 1, 2)``, where
    ``dst2`` is the name of the worker two ranks ahead of this one.
    """
    world = self.world_size
    dst1 = worker_name((self.rank + 1) % world)
    dst2 = worker_name((self.rank + 2) % world)
    call_args = (dst2, torch.ones(2, 2), 1, 2)
    ret = rpc.rpc_sync(dst1, func, args=call_args)
    self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
@dist_init
def test_future_in_rpc(self):
    """Run the shared future-callback check with ``add_use_future_set_result``."""
    self._test_future_cb(add_use_future_set_result)
@dist_init
def test_future_nested_callback(self):
    """Run the shared future-callback check with ``add_use_future_nested_cb``."""
    self._test_future_cb(add_use_future_nested_cb)
class FaultyAgentRpcTest(FaultyRpcAgentTestFixture): class FaultyAgentRpcTest(FaultyRpcAgentTestFixture):