[futures] Add specific python unittest coverage for collect_all/wait_all (#40233)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40233

There was a question earlier whether torch.futures.wait_all() would
raised if the underlying futures raise (it was supposed to, but no test
coverage). This change adds a couple very basic torch.futures.collect_all/
wait_all tests.
ghstack-source-id: 106168134

Test Plan: buck test mode/dev-nosan caffe2/test:futures

Differential Revision: D22120284

fbshipit-source-id: 3a8edae5dbf8c58c8361eff156c386a684ec5e86
This commit is contained in:
Jeremy Lilley
2020-06-18 16:12:20 -07:00
committed by Facebook GitHub Bot
parent 8b5732e8ad
commit 35f357927d

View File

@ -96,3 +96,39 @@ class TestFuture(TestCase):
raise ValueError("Expected error")
self._test_error(raise_value_error, "Expected error")
def test_collect_all(self):
fut1 = Future()
fut2 = Future()
fut_all = torch.futures.collect_all([fut1, fut2])
def slow_in_thread(fut, value):
time.sleep(0.1)
fut.set_result(value)
t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
fut2.set_result(2)
t.start()
res = fut_all.wait()
self.assertEqual(res[0].wait(), 1)
self.assertEqual(res[1].wait(), 2)
t.join()
def test_wait_all(self):
fut1 = Future()
fut2 = Future()
# No error version
fut1.set_result(1)
fut2.set_result(2)
res = torch.futures.wait_all([fut1, fut2])
print(res)
self.assertEqual(res, [1, 2])
# Version with an exception
def raise_in_fut(fut):
raise ValueError("Expected error")
fut3 = fut1.then(raise_in_fut)
with self.assertRaisesRegex(RuntimeError, "Expected error"):
torch.futures.wait_all([fut3, fut2])