Revert D21976891: [futures] Add torch.futures.collect_all()/wait_all() python api.

Test Plan: revert-hammer

Differential Revision:
D21976891

Original commit changeset: 253c61f503f4

fbshipit-source-id: f839b16f4469e96325b607b6313a1397e1988856
This commit is contained in:
Mike Ruberry
2020-06-12 13:38:46 -07:00
committed by Facebook GitHub Bot
parent 14099374bd
commit 8bc821f0d0
4 changed files with 11 additions and 50 deletions

View File

@ -956,21 +956,6 @@ void initJITBindings(PyObject* module) {
return fut->wait();
});
m.def(
"_collect_all",
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
-> std::shared_ptr<jit::PythonFutureWrapper> {
auto typePtr =
futures.empty() ? AnyType::get() : futures[0]->fut->elementType();
c10::List<c10::intrusive_ptr<ivalue::Future>> asList(FutureType::create(typePtr));
asList.reserve(futures.size());
for (const auto& f : futures) {
asList.push_back(f->fut);
}
return std::make_shared<jit::PythonFutureWrapper>(
c10::collectAll(asList));
});
m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
toIValue(obj, type);
});

View File

@ -94,30 +94,3 @@ class Future(torch._C.Future):
>>> t.join()
"""
super(Future, self).set_result(result)
def collect_all(futures):
r"""
Collects the Futures into a single combined Future that is completed
when all of the sub-futures are completed.
Arguments:
futures: a list of Futures
Returns:
Returns a Future object to a list of the passed in Futures.
"""
return torch._C._collect_all(futures)
def wait_all(futures):
r"""
Waits for all provided futures to be complete, and returns
the list of completed values.
Arguments:
futures: a list of Futures
Returns:
A list of the completed Future results
"""
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]

View File

@ -1045,7 +1045,8 @@ class DistAutogradTest(RpcAgentTestFixture):
_set_rpc_done(None, 0)
# wait until all trainers are done
torch.futures.wait_all(futures)
for fut in futures:
fut.wait()
@dist_init
def test_trainer_ps(self):

View File

@ -822,7 +822,8 @@ class RpcTest(RpcAgentTestFixture):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)
for fut in torch.futures.collect_all(futs).wait():
for fut in futs:
fut.wait()
self.assertEqual(fut.wait(), 0)
# Phase 2: Only worker2 has workload.
@ -834,8 +835,9 @@ class RpcTest(RpcAgentTestFixture):
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
futs.append(fut)
for val in torch.futures.wait_all(futs):
self.assertEqual(val, 0)
for fut in futs:
fut.wait()
self.assertEqual(fut.wait(), 0)
def test_wait_all_workers(self):
rpc.init_rpc(
@ -1240,9 +1242,9 @@ class RpcTest(RpcAgentTestFixture):
futs.append(fut)
j = 0
for val in torch.futures.wait_all(futs):
for fut in futs:
self.assertEqual(
val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
)
j += 1
@ -1308,8 +1310,8 @@ class RpcTest(RpcAgentTestFixture):
fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
futs.append(fut)
for val in torch.futures.wait_all(futs):
self.assertEqual(val, 0)
for fut in futs:
self.assertEqual(fut.wait(), 0)
tok = time.time()
print(
"Rank {} finished testing {} times in {} seconds.".format(