mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50983 There is currently no way to handle/propagate errors with the python-based futures API (they are raised correctly if set with an error, but this is only possible from C++). This diff allows the Future's `unwrap_func` to be set in python optionally, so users can set futures completed with an exception and the error will throw as expected. This is mostly to support the following use case in the next diff: ``` ret_fut = torch.futures.Future(unwrap_func = lambda python_result: { # throw exception if needed if isinstance(python_result, Exception): throw python_result }) rpc_fut = rpc.rpc_async(...) # RPC future that times out # Goal is to propagate RPC error to this future rpc_fut.add_done_callback( res => { # Note that ret_fut.set_result(res.wait()) won't propagate the error try: ret_fut.set_result(res.wait()) except Exception as e: ret_fut.set_result(e) } ) ``` ghstack-source-id: 121021434 Test Plan: unittest ``` buck test mode/dev-nosan mode/no-gpu //caffe2/test:futures -- test_unwrap --print-passing-details ``` Reviewed By: mrshenli Differential Revision: D25950304 fbshipit-source-id: 7ee61e98fcd783b3f515706fa141d538e6d2174d
330 lines
9.7 KiB
Python
330 lines
9.7 KiB
Python
import threading
|
|
import time
|
|
import torch
|
|
import unittest
|
|
from torch.futures import Future
|
|
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
|
|
from typing import TypeVar
|
|
|
|
# Placeholder type variable; the tests subscript Future with it (``Future[T]()``)
# when the concrete result type is irrelevant to what is being checked.
T = TypeVar("T")
|
|
|
|
|
|
def add_one(fut):
    """Return the completed value of ``fut`` incremented by one.

    Intended for use as a ``Future.then`` callback: it receives the future
    itself, blocks on ``wait()`` for its value, and adds 1 to that value.
    """
    completed_value = fut.wait()
    return completed_value + 1
|
|
|
|
|
|
class TestFuture(TestCase):
    """Tests for ``torch.futures.Future``.

    Covers completion state (``done``), blocking retrieval (``wait`` /
    ``value``), error propagation via ``set_exception``, callback chaining
    with ``then`` and ``add_done_callback``, and the module-level helpers
    ``collect_all`` and ``wait_all``.
    """

    def test_set_exception(self) -> None:
        # This test is to ensure errors can propagate across futures.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        f = Future[T]()
        # Set exception
        f.set_exception(value_error)  # type: ignore
        # Exception should throw on wait
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.wait()

        # Exception should also throw on value
        f = Future()
        f.set_exception(value_error)  # type: ignore
        with self.assertRaisesRegex(ValueError, "Intentional"):
            f.value()  # type: ignore

        def cb(fut):
            fut.value()  # type: ignore

        f = Future()
        f.set_exception(value_error)  # type: ignore

        # A callback that re-reads the failed future surfaces the failure
        # through the future returned by then(), as a RuntimeError.
        with self.assertRaisesRegex(RuntimeError, "Got the following error"):
            cb_fut = f.then(cb)
            cb_fut.wait()

    def test_set_exception_multithreading(self) -> None:
        # Ensure errors can propagate when one thread waits on future result
        # and the other sets it with an error.
        error_msg = "Intentional Value Error"
        value_error = ValueError(error_msg)

        def run_against_failing_future(thread_body) -> None:
            # Runs thread_body(f) on a worker thread, completes f with
            # value_error from this thread, then re-raises anything the worker
            # raised. An exception escaping a thread's target never reaches
            # the main thread on its own, so without this a failed assertion
            # inside the worker could not fail the test.
            worker_failures = []

            def guarded(fut):
                try:
                    thread_body(fut)
                except Exception as exc:
                    worker_failures.append(exc)

            f = Future[T]()
            t = threading.Thread(target=guarded, args=(f, ))
            t.start()
            f.set_exception(value_error)  # type: ignore
            t.join()
            if worker_failures:
                raise worker_failures[0]

        def wait_future(f):
            with self.assertRaisesRegex(ValueError, "Intentional"):
                f.wait()

        run_against_failing_future(wait_future)

        def cb(fut):
            fut.value()  # type: ignore

        def then_future(f):
            fut = f.then(cb)
            with self.assertRaisesRegex(RuntimeError, "Got the following error"):
                fut.wait()

        run_against_failing_future(then_future)

    def test_done(self) -> None:
        # done() is False until a result is set and True afterwards.
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())

        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        # A future produced by then() reports done() even when its callback
        # raised; the error only shows up on wait().
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        # wait() on an already-completed future returns the stored result.
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))

        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:
        # wait() returns the result that another thread sets later on.

        def slow_set_future(fut, value):
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()

        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()

        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self) -> None:
        # Setting a result on an already-completed future is an error.
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        # Saving a Future with torch.save is rejected with a RuntimeError.
        fut = Future[int]()
        errMsg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, errMsg):
                torch.save(fut, fname)

    def test_then(self):
        # A then() callback registered before completion runs on the result,
        # and the original future keeps its own value.
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        # Chain 20 add_one callbacks; the i-th chained future (0-based) must
        # hold the initial value plus i + 1.
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        for i, chained_fut in enumerate(futs):
            self.assertEqual(chained_fut.wait(), torch.ones(2, 2) + i + 1)

    def _test_then_error(self, cb, errMsg):
        # Helper: register cb through then(), complete the source future with
        # 5, and check that the source is unaffected while waiting on the
        # then-future raises a RuntimeError matching errMsg.
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        # The callback wrongly treats its argument (a Future) as a number.

        def wrong_arg(tensor):
            return tensor + 1

        self._test_then_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        # The callback accepts no arguments, but then() passes it one.

        def no_arg():
            return True

        self._test_then_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        # An exception raised by the callback surfaces from the then-future.

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_then_error(raise_value_error, "Expected error")

    def test_add_done_callback_simple(self):
        # The callback must not run before completion and must have run once
        # the future has been completed.
        callback_result = False

        def callback(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = True

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback)

        self.assertFalse(callback_result)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertTrue(callback_result)

    def test_add_done_callback_maintains_callback_order(self):
        # Two callbacks write different values to the same variable; the final
        # value shows which one ran last.
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def _test_add_done_callback_error_ignored(self, cb):
        # Helper: register cb through add_done_callback(), complete the future
        # with 5, and check the future still yields 5 from wait().
        fut = Future[int]()
        fut.add_done_callback(cb)

        fut.set_result(5)
        # error msg logged to stdout
        self.assertEqual(5, fut.wait())

    def test_add_done_callback_error_is_ignored(self):
        # A raising done-callback must not affect the future's result.

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_add_done_callback_error_ignored(raise_value_error)

    def test_add_done_callback_no_arg_error_is_ignored(self):
        # A done-callback with the wrong arity must not affect the result.

        def no_arg():
            return True

        # Adding another level of function indirection here on purpose.
        # Otherwise mypy will pick up on no_arg having an incompatible type and fail CI
        self._test_add_done_callback_error_ignored(no_arg)

    def test_interleaving_then_and_add_done_callback_maintains_callback_order(self):
        # Register add_done_callback, then(), add_done_callback in that order.
        # The value the then-callback observes and the final value together
        # pin down the order in which the three callbacks ran.
        callback_result = 0

        def callback_set1(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 1

        def callback_set2(fut):
            nonlocal callback_result
            fut.wait()
            callback_result = 2

        def callback_then(fut):
            nonlocal callback_result
            return fut.wait() + callback_result

        fut = Future[torch.Tensor]()
        fut.add_done_callback(callback_set1)
        then_fut = fut.then(callback_then)
        fut.add_done_callback(callback_set2)

        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        # then_fut's callback is called with callback_result = 1
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
        # set2 called last, callback_result = 2
        self.assertEqual(callback_result, 2)

    def test_interleaving_then_and_add_done_callback_propagates_error(self):
        # The same raising callback is registered through both APIs.
        def raise_value_error(fut):
            raise ValueError("Expected error")

        fut = Future[torch.Tensor]()
        then_fut = fut.then(raise_value_error)
        fut.add_done_callback(raise_value_error)
        fut.set_result(torch.ones(2, 2))

        # error from add_done_callback's callback is swallowed
        # error from then's callback is not
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            then_fut.wait()

    def test_collect_all(self):
        # collect_all() yields the input futures, in order, once all of them
        # are complete; one is completed here and the other on a thread.
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()

        res = fut_all.wait()
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        # wait_all() returns the plain results in input order, and raises if
        # any of the futures completed with an error.
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")
        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])
|
|
|
|
# Script entry point: when this file is executed directly (rather than
# imported), hand control to the shared test runner.
if __name__ == '__main__':
    run_tests()
|