mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[futures] Add specific python unittest coverage for collect_all/wait_all (#40233)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40233 There was a question earlier whether torch.futures.wait_all() would raised if the underlying futures raise (it was supposed to, but no test coverage). This change adds a couple very basic torch.futures.collect_all/ wait_all tests. ghstack-source-id: 106168134 Test Plan: buck test mode/dev-nosan caffe2/test:futures Differential Revision: D22120284 fbshipit-source-id: 3a8edae5dbf8c58c8361eff156c386a684ec5e86
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8b5732e8ad
commit
35f357927d
@ -96,3 +96,39 @@ class TestFuture(TestCase):
|
||||
raise ValueError("Expected error")
|
||||
|
||||
self._test_error(raise_value_error, "Expected error")
|
||||
|
||||
def test_collect_all(self):
|
||||
fut1 = Future()
|
||||
fut2 = Future()
|
||||
fut_all = torch.futures.collect_all([fut1, fut2])
|
||||
|
||||
def slow_in_thread(fut, value):
|
||||
time.sleep(0.1)
|
||||
fut.set_result(value)
|
||||
|
||||
t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
|
||||
fut2.set_result(2)
|
||||
t.start()
|
||||
|
||||
res = fut_all.wait()
|
||||
self.assertEqual(res[0].wait(), 1)
|
||||
self.assertEqual(res[1].wait(), 2)
|
||||
t.join()
|
||||
|
||||
def test_wait_all(self):
|
||||
fut1 = Future()
|
||||
fut2 = Future()
|
||||
|
||||
# No error version
|
||||
fut1.set_result(1)
|
||||
fut2.set_result(2)
|
||||
res = torch.futures.wait_all([fut1, fut2])
|
||||
print(res)
|
||||
self.assertEqual(res, [1, 2])
|
||||
|
||||
# Version with an exception
|
||||
def raise_in_fut(fut):
|
||||
raise ValueError("Expected error")
|
||||
fut3 = fut1.then(raise_in_fut)
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected error"):
|
||||
torch.futures.wait_all([fut3, fut2])
|
||||
|
Reference in New Issue
Block a user