mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expose torch.futures.Future (#39008)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39008 This commit adds a `torch.futures.Future` type and exposes its ctor, `wait`, `then`, and `set_result` APIs. This type is currently a wrapper of `c10::ivalue::Future` and mainly used by RPC for now. Later, we could revamp c10d APIs to return this `Future` type as well. More utils will be added into `torch.futures` package in followup PRs. Test Plan: Imported from OSS Differential Revision: D21723022 Pulled By: mrshenli fbshipit-source-id: 92e56160544e9bf00d11db3e8347a1b9707882c9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b3fac8af6b
commit
bb0377bb24
@ -259,7 +259,10 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
|
|||||||
*/
|
*/
|
||||||
void markCompleted(IValue value) {
|
void markCompleted(IValue value) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_);
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
AT_ASSERT(!completed());
|
TORCH_CHECK(
|
||||||
|
!completed(),
|
||||||
|
"Attempting to mark a completed Future as complete again. Note that "
|
||||||
|
"a Future can only be marked completed once.");
|
||||||
completed_ = true;
|
completed_ = true;
|
||||||
value_ = std::move(value);
|
value_ = std::move(value);
|
||||||
|
|
||||||
|
3
mypy.ini
3
mypy.ini
@ -49,6 +49,9 @@ ignore_errors = True
|
|||||||
[mypy-torch.functional.*]
|
[mypy-torch.functional.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
[mypy-torch.futures.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch.testing._internal.*]
|
[mypy-torch.testing._internal.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
@ -79,6 +79,7 @@ TESTS = [
|
|||||||
'test_determination',
|
'test_determination',
|
||||||
'distributed/rpc/jit/test_rpc_spawn',
|
'distributed/rpc/jit/test_rpc_spawn',
|
||||||
'distributed/rpc/faulty_agent/test_rpc_spawn',
|
'distributed/rpc/faulty_agent/test_rpc_spawn',
|
||||||
|
'test_futures',
|
||||||
]
|
]
|
||||||
|
|
||||||
WINDOWS_BLACKLIST = [
|
WINDOWS_BLACKLIST = [
|
||||||
@ -154,6 +155,7 @@ SLOW_TESTS = [
|
|||||||
'distributed/test_c10d_spawn',
|
'distributed/test_c10d_spawn',
|
||||||
'test_quantization',
|
'test_quantization',
|
||||||
'test_determination',
|
'test_determination',
|
||||||
|
'test_futures',
|
||||||
]
|
]
|
||||||
_DEP_MODULES_CACHE = {}
|
_DEP_MODULES_CACHE = {}
|
||||||
|
|
||||||
|
98
test/test_futures.py
Normal file
98
test/test_futures.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from torch.futures import Future
|
||||||
|
from torch.testing._internal.common_utils import TestCase, TemporaryFileName
|
||||||
|
|
||||||
|
|
||||||
|
def add_one(fut):
    """Callback for ``Future.then``: return the future's value plus one.

    Args:
        fut: the future handed to the callback; only ``fut.wait()`` is used.

    Returns:
        ``fut.wait() + 1``.
    """
    return fut.wait() + 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestFuture(TestCase):
    """Unit tests for the Python-facing ``torch.futures.Future`` API."""

    def test_wait(self):
        """``wait`` returns the value previously passed to ``set_result``."""
        f = Future()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self):
        """``wait`` blocks until another thread calls ``set_result``."""

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        try:
            self.assertEqual(f.wait(), torch.ones(2, 2))
        finally:
            # Join even when the assertion fails so the worker thread does not
            # outlive this test.
            t.join()

    def test_mark_future_twice(self):
        """Calling ``set_result`` on a completed future raises ``RuntimeError``."""
        fut = Future()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        """``torch.save`` on a future raises instead of pickling it."""
        fut = Future()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        """A ``then`` callback's return value becomes the new future's value."""
        fut = Future()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        """Twenty chained ``add_one`` callbacks each add one to the previous value."""
        fut = Future()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future has had ``add_one`` applied i + 1 times.
        for i, chained_fut in enumerate(futs):
            self.assertEqual(chained_fut.wait(), torch.ones(2, 2) + i + 1)

    def _test_error(self, cb, errMsg):
        """Attach ``cb`` via ``then`` and expect the resulting future to fail.

        Args:
            cb: callback passed to ``Future.then``; expected to raise when run.
            errMsg: regex that the ``RuntimeError`` raised by waiting on the
                callback's future must match.
        """
        fut = Future()
        then_fut = fut.then(cb)

        fut.set_result(5)
        # The original future is unaffected by the failing callback.
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        """A callback that treats its argument as a value, not a future, fails."""

        def wrong_arg(tensor):
            return tensor + 1

        self._test_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        """A callback that accepts no arguments fails when invoked."""

        def no_arg():
            return True

        self._test_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        """An exception raised inside a callback surfaces from ``wait``."""

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_error(raise_value_error, "Expected error")
|
@ -412,6 +412,7 @@ del ComplexFloatStorageBase
|
|||||||
import torch.cuda
|
import torch.cuda
|
||||||
import torch.autograd
|
import torch.autograd
|
||||||
from torch.autograd import no_grad, enable_grad, set_grad_enabled
|
from torch.autograd import no_grad, enable_grad, set_grad_enabled
|
||||||
|
import torch.futures
|
||||||
import torch.nn
|
import torch.nn
|
||||||
import torch.nn.intrinsic
|
import torch.nn.intrinsic
|
||||||
import torch.nn.quantized
|
import torch.nn.quantized
|
||||||
|
@ -833,13 +833,40 @@ void initJITBindings(PyObject* module) {
|
|||||||
|
|
||||||
py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
|
py::class_<PythonFutureWrapper, std::shared_ptr<PythonFutureWrapper>>(
|
||||||
m, "Future")
|
m, "Future")
|
||||||
|
.def(py::init([]() {
|
||||||
|
return std::make_shared<PythonFutureWrapper>(
|
||||||
|
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
|
||||||
|
}))
|
||||||
.def(
|
.def(
|
||||||
"wait",
|
"wait",
|
||||||
&PythonFutureWrapper::wait,
|
&PythonFutureWrapper::wait,
|
||||||
py::call_guard<py::gil_scoped_release>())
|
py::call_guard<py::gil_scoped_release>())
|
||||||
.def(
|
.def(
|
||||||
"_then",
|
"then",
|
||||||
&PythonFutureWrapper::then,
|
&PythonFutureWrapper::then,
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def(
|
||||||
|
"set_result",
|
||||||
|
// Intentionally not releasing GIL
|
||||||
|
&PythonFutureWrapper::markCompleted)
|
||||||
|
.def(
|
||||||
|
py::pickle(
|
||||||
|
/* __getstate__ */
|
||||||
|
[](const PythonFutureWrapper& /* unused */) {
|
||||||
|
TORCH_CHECK(false, "Can not pickle torch.futures.Future");
|
||||||
|
// Note that this return has no meaning since we always
|
||||||
|
// throw, it's only here to satisfy Pybind API's
|
||||||
|
// requirement.
|
||||||
|
return py::make_tuple();
|
||||||
|
},
|
||||||
|
/* __setstate__ */
|
||||||
|
[](const py::tuple& /* unused */) { // NOLINT
|
||||||
|
TORCH_CHECK(false, "Can not unpickle torch.futures.Future");
|
||||||
|
// Note that this return has no meaning since we always
|
||||||
|
// throw, it's only here to satisfy PyBind's API
|
||||||
|
// requirement.
|
||||||
|
return nullptr;
|
||||||
|
}),
|
||||||
py::call_guard<py::gil_scoped_release>());
|
py::call_guard<py::gil_scoped_release>());
|
||||||
|
|
||||||
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
|
m.def("fork", [](const py::args& args, const py::kwargs& kwargs) {
|
||||||
|
@ -71,7 +71,7 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
|
|||||||
c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
|
c10::optional<UnwrapFunc> unwrap_func = c10::nullopt)
|
||||||
: fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}
|
: fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}
|
||||||
|
|
||||||
PythonFutureWrapper(const PythonFutureWrapper&) = delete;
|
explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
|
||||||
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
|
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
|
||||||
|
|
||||||
py::object wait() {
|
py::object wait() {
|
||||||
@ -136,6 +136,14 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
|
|||||||
PyObjectType::get()));
|
PyObjectType::get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void markCompleted(const py::object& pyValue) {
  // Debug-only check that the caller holds the GIL: the conversion below
  // operates on a Python object.
  DCHECK(PyGILState_Check());
  // Convert the Python object to an IValue while the GIL is still held.
  IValue value = toIValue(pyValue, PyObjectType::get());

  // Release the GIL for the remainder of this scope, i.e. while the
  // underlying future is marked completed.
  // NOTE(review): presumably so that work triggered by completion can take
  // the GIL itself -- confirm against ivalue::Future::markCompleted.
  py::gil_scoped_release release;
  fut->markCompleted(std::move(value));
}
|
||||||
|
|
||||||
c10::intrusive_ptr<c10::ivalue::Future> fut;
|
c10::intrusive_ptr<c10::ivalue::Future> fut;
|
||||||
// unwrap_func works like a callback for the value returned by
|
// unwrap_func works like a callback for the value returned by
|
||||||
// PythonFutureWrapper::wait().
|
// PythonFutureWrapper::wait().
|
||||||
|
96
torch/futures/__init__.py
Normal file
96
torch/futures/__init__.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
The ``torch.futures`` package contains a ``Future`` type and corresponding
|
||||||
|
utility functions.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Future(torch._C.Future):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning::
        The ``torch.futures.Future`` is experimental and subject to change.
    """
    def __new__(cls):
        return super(Future, cls).__new__(cls)

    def wait(self):
        r"""
        Block until the value of this ``Future`` is ready.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        return super(Future, self).wait()

    def then(self, callback):
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, and will be invoked in the same order as they were
        added. The callback must take one argument, which is the reference to
        this ``Future``. The callback function can use the ``Future.wait()`` API
        to get the value.

        Arguments:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                                    the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        Example::
            >>> import torch
            >>>
            >>> def callback(fut):
            >>>     print(f"RPC return value is {fut.wait()}.")
            >>>
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value once
            >>> # ``fut`` is completed by ``set_result`` below
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(lambda x : print(f"Chained cb done. {x.wait()}"))
            >>> fut.set_result(5)
            >>>
            >>> # Outputs are:
            >>> # RPC return value is 5.
            >>> # Chained cb done. None
        """
        return super(Future, self).then(callback)

    def set_result(self, result):
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        Arguments:
            result (object): the result object of this ``Future``.

        Example::
            >>> import threading
            >>> import time
            >>> import torch
            >>>
            >>> def slow_set_future(fut, value):
            >>>     time.sleep(0.5)
            >>>     fut.set_result(value)
            >>>
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            >>>     target=slow_set_future,
            >>>     args=(fut, torch.ones(2) * 3)
            >>> )
            >>> t.start()
            >>>
            >>> print(fut.wait())  # tensor([3., 3.])
            >>> t.join()
        """
        super(Future, self).set_result(result)
|
@ -944,7 +944,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
|
|||||||
worker_name((self.rank + 1) % self.world_size),
|
worker_name((self.rank + 1) % self.world_size),
|
||||||
script_fork_wait_udf,
|
script_fork_wait_udf,
|
||||||
args=(torch.ones(2),)
|
args=(torch.ones(2),)
|
||||||
)._then(callback)
|
).then(callback)
|
||||||
self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
|
self.assertEqual(future.wait(), torch.ones(2) * 2 + 1)
|
||||||
|
|
||||||
@dist_init
|
@dist_init
|
||||||
@ -963,7 +963,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
|
|||||||
|
|
||||||
num_cbs = 20
|
num_cbs = 20
|
||||||
for _ in range(num_cbs):
|
for _ in range(num_cbs):
|
||||||
fut = fut._then(callback)
|
fut = fut.then(callback)
|
||||||
|
|
||||||
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
|
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
|
||||||
|
|
||||||
@ -988,7 +988,7 @@ class JitRpcTest(RRefAPITest, RRefTypingTest, LocalRRefTest, JitRpcAsyncOpTest,
|
|||||||
worker_name((self.rank + 1) % self.world_size),
|
worker_name((self.rank + 1) % self.world_size),
|
||||||
script_fork_wait_throw,
|
script_fork_wait_throw,
|
||||||
args=(torch.ones(2),)
|
args=(torch.ones(2),)
|
||||||
)._then(callback)
|
).then(callback)
|
||||||
|
|
||||||
with self.assertRaisesRegex(RuntimeError, "Another expected error"):
|
with self.assertRaisesRegex(RuntimeError, "Another expected error"):
|
||||||
future.wait()
|
future.wait()
|
||||||
|
@ -312,7 +312,7 @@ def add_use_future_cb(to, x, y, z):
|
|||||||
out.set_result(fut.wait() + z)
|
out.set_result(fut.wait() + z)
|
||||||
|
|
||||||
fut = rpc.rpc_async(to, torch.add, args=(x, y))
|
fut = rpc.rpc_async(to, torch.add, args=(x, y))
|
||||||
fut._then(callback)
|
fut.then(callback)
|
||||||
return out.result()
|
return out.result()
|
||||||
|
|
||||||
|
|
||||||
@ -320,6 +320,29 @@ def get_events_from_profile(profile_rref):
|
|||||||
return profile_rref.local_value().process_global_function_events
|
return profile_rref.local_value().process_global_function_events
|
||||||
|
|
||||||
|
|
||||||
|
def add_use_future_set_result(to, x, y, z):
    """Compute ``x + y`` on worker ``to`` over RPC, then add ``z`` locally.

    The RPC future's callback completes a separate ``torch.futures.Future``
    via ``set_result``; this function blocks on that future.

    Args:
        to: destination worker passed to ``rpc.rpc_async``.
        x, y: operands of the remote ``torch.add``.
        z: value added to the RPC result inside the callback.

    Returns:
        The value of the locally created future, i.e. ``(x + y) + z``.
    """
    out = torch.futures.Future()
    fut = rpc.rpc_async(to, torch.add, args=(x, y))
    # The lambda parameter has its own name so it does not shadow ``fut``.
    fut.then(lambda done_fut: out.set_result(done_fut.wait() + z))
    return out.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def add_use_future_nested_cb(to, x, y, z):
    """Compute ``(x + y) + z`` with two RPCs, the second issued from a callback.

    The callback on the first RPC future launches a second ``rpc_async``; the
    callback on that second future completes a local ``torch.futures.Future``
    which this function waits on.

    Args:
        to: destination worker for both ``rpc.rpc_async`` calls.
        x, y: operands of the first remote ``torch.add``.
        z: second operand of the nested remote ``torch.add``.

    Returns:
        The value set on the local future by the innermost callback.
    """
    out = torch.futures.Future()

    def callback(fut1):
        fut2 = rpc.rpc_async(to, torch.add, args=(fut1.wait(), z))
        # The lambda parameter has its own name so it does not shadow ``fut2``.
        fut2.then(lambda done_fut: out.set_result(done_fut.wait()))

    fut1 = rpc.rpc_async(to, torch.add, args=(x, y))
    fut1.then(callback)
    return out.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def fail_on_fut(fut):
    """No-op RPC target that accepts a single argument and returns ``None``.

    The body does nothing; the function exists only so a call site can pass
    ``fut`` as an RPC argument.
    """
    pass
|
||||||
|
|
||||||
|
|
||||||
# load_tests from common_utils is used to automatically filter tests for
|
# load_tests from common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
load_tests = load_tests
|
load_tests = load_tests
|
||||||
@ -2481,7 +2504,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
args=(torch.ones(n, n), torch.ones(n, n))
|
args=(torch.ones(n, n), torch.ones(n, n))
|
||||||
)
|
)
|
||||||
|
|
||||||
fut._then(callback)
|
fut.then(callback)
|
||||||
|
|
||||||
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
||||||
self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
|
self.assertEqual(set_by_cb.result(), torch.ones(n, n) * 2 + 1)
|
||||||
@ -2498,7 +2521,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
args=(torch.ones(n, n), torch.ones(n, n))
|
args=(torch.ones(n, n), torch.ones(n, n))
|
||||||
)
|
)
|
||||||
|
|
||||||
cb_fut = fut._then(my_function)
|
cb_fut = fut.then(my_function)
|
||||||
|
|
||||||
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
||||||
|
|
||||||
@ -2513,7 +2536,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
dst = worker_name((self.rank + 1) % self.world_size)
|
dst = worker_name((self.rank + 1) % self.world_size)
|
||||||
|
|
||||||
fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
|
fut0 = rpc.rpc_async(dst, torch.add, args=(torch.ones(2, 2), 1))
|
||||||
fut1 = fut0._then(lambda x: x + 1)
|
fut1 = fut0.then(lambda x: x + 1)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
@ -2539,7 +2562,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
|
|
||||||
cb_futs = []
|
cb_futs = []
|
||||||
for idx in range(num_cbs):
|
for idx in range(num_cbs):
|
||||||
cb_futs.append(fut._then(partial(callback, idx)))
|
cb_futs.append(fut.then(partial(callback, idx)))
|
||||||
|
|
||||||
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
|
||||||
|
|
||||||
@ -2567,7 +2590,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
|
|
||||||
num_cbs = 20
|
num_cbs = 20
|
||||||
for _ in range(num_cbs):
|
for _ in range(num_cbs):
|
||||||
fut = fut._then(callback)
|
fut = fut.then(callback)
|
||||||
|
|
||||||
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
|
self.assertEqual(fut.wait(), torch.ones(n, n) + 1 + num_cbs)
|
||||||
|
|
||||||
@ -2592,7 +2615,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
dst,
|
dst,
|
||||||
torch.add,
|
torch.add,
|
||||||
args=(fut0.wait(), 1)
|
args=(fut0.wait(), 1)
|
||||||
)._then(lambda fut1: fut1.wait() + 1)
|
).then(lambda fut1: fut1.wait() + 1)
|
||||||
|
|
||||||
return fut2.wait()
|
return fut2.wait()
|
||||||
|
|
||||||
@ -2600,7 +2623,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
dst,
|
dst,
|
||||||
torch.add,
|
torch.add,
|
||||||
args=(torch.ones(2, 2), 1)
|
args=(torch.ones(2, 2), 1)
|
||||||
)._then(callback)
|
).then(callback)
|
||||||
|
|
||||||
self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
|
self.assertEqual(fut3.wait(), torch.ones(2, 2) + 3)
|
||||||
|
|
||||||
@ -2613,7 +2636,7 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
fut0.wait()
|
fut0.wait()
|
||||||
raise RuntimeError("Another expected error")
|
raise RuntimeError("Another expected error")
|
||||||
|
|
||||||
fut1 = rpc.rpc_async(dst, raise_func)._then(callback)
|
fut1 = rpc.rpc_async(dst, raise_func).then(callback)
|
||||||
with self.assertRaisesRegex(RuntimeError, "Another expected error"):
|
with self.assertRaisesRegex(RuntimeError, "Another expected error"):
|
||||||
fut1.wait()
|
fut1.wait()
|
||||||
|
|
||||||
@ -2625,7 +2648,58 @@ class RpcTest(RpcAgentTestFixture):
|
|||||||
TypeError,
|
TypeError,
|
||||||
"incompatible function arguments."
|
"incompatible function arguments."
|
||||||
):
|
):
|
||||||
rpc.rpc_async(dst, raise_func)._then(None)
|
rpc.rpc_async(dst, raise_func).then(None)
|
||||||
|
|
||||||
|
@dist_init
|
||||||
|
def test_mark_future_twice(self):
    """``set_result`` on an already completed RPC future raises ``RuntimeError``."""
    fut = rpc.rpc_async(
        worker_name((self.rank + 1) % self.world_size),
        torch.add,
        args=(torch.zeros(2, 2), 1)
    )
    # Waiting first guarantees the future has completed before the second
    # attempt to mark it completed.
    self.assertEqual(fut.wait(), torch.zeros(2, 2) + 1)
    with self.assertRaisesRegex(
        RuntimeError,
        "Future can only be marked completed once"
    ):
        fut.set_result(1)
|
||||||
|
|
||||||
|
@dist_init
|
||||||
|
def test_pickle_future(self):
    """Passing a ``Future`` as an RPC argument raises for every RPC API.

    Checks ``rpc_sync``, ``rpc_async`` and ``remote``; each call itself is
    expected to raise ``RuntimeError`` with the pickling error message.
    """
    fut = torch.futures.Future()
    errMsg = "Can not pickle torch.futures.Future"

    dst = worker_name((self.rank + 1) % self.world_size)
    # No temporary file is needed: nothing here is written to disk.
    for rpc_api in (rpc.rpc_sync, rpc.rpc_async, rpc.remote):
        with self.assertRaisesRegex(RuntimeError, errMsg):
            rpc_api(dst, fail_on_fut, args=(fut,))
|
||||||
|
|
||||||
|
def _test_future_cb(self, func):
    """Run ``func`` on the next worker and check that it returns ``ones + 1 + 2``.

    Args:
        func: callable invoked remotely as ``func(dst2, torch.ones(2, 2), 1, 2)``,
            where ``dst2`` is the worker two ranks after this one.
    """
    dst1 = worker_name((self.rank + 1) % self.world_size)
    dst2 = worker_name((self.rank + 2) % self.world_size)

    ret = rpc.rpc_sync(
        dst1,
        func,
        args=(dst2, torch.ones(2, 2), 1, 2)
    )
    self.assertEqual(ret, torch.ones(2, 2) + 1 + 2)
|
||||||
|
|
||||||
|
@dist_init
|
||||||
|
def test_future_in_rpc(self):
    """Exercise ``add_use_future_set_result`` through ``_test_future_cb``."""
    self._test_future_cb(add_use_future_set_result)
|
||||||
|
|
||||||
|
@dist_init
|
||||||
|
def test_future_nested_callback(self):
    """Exercise ``add_use_future_nested_cb`` through ``_test_future_cb``."""
    self._test_future_cb(add_use_future_nested_cb)
|
||||||
|
|
||||||
|
|
||||||
class FaultyAgentRpcTest(FaultyRpcAgentTestFixture):
|
class FaultyAgentRpcTest(FaultyRpcAgentTestFixture):
|
||||||
|
Reference in New Issue
Block a user