mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D21976891: [futures] Add torch.futures.collect_all()/wait_all() python api.
Test Plan: revert-hammer Differential Revision: D21976891 Original commit changeset: 253c61f503f4 fbshipit-source-id: f839b16f4469e96325b607b6313a1397e1988856
This commit is contained in:
committed by
Facebook GitHub Bot
parent
14099374bd
commit
8bc821f0d0
@ -956,21 +956,6 @@ void initJITBindings(PyObject* module) {
|
||||
return fut->wait();
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_collect_all",
|
||||
[](const std::vector<std::shared_ptr<jit::PythonFutureWrapper>>& futures)
|
||||
-> std::shared_ptr<jit::PythonFutureWrapper> {
|
||||
auto typePtr =
|
||||
futures.empty() ? AnyType::get() : futures[0]->fut->elementType();
|
||||
c10::List<c10::intrusive_ptr<ivalue::Future>> asList(FutureType::create(typePtr));
|
||||
asList.reserve(futures.size());
|
||||
for (const auto& f : futures) {
|
||||
asList.push_back(f->fut);
|
||||
}
|
||||
return std::make_shared<jit::PythonFutureWrapper>(
|
||||
c10::collectAll(asList));
|
||||
});
|
||||
|
||||
m.def("_jit_assert_is_instance", [](py::object obj, TypePtr type) {
|
||||
toIValue(obj, type);
|
||||
});
|
||||
|
@ -94,30 +94,3 @@ class Future(torch._C.Future):
|
||||
>>> t.join()
|
||||
"""
|
||||
super(Future, self).set_result(result)
|
||||
|
||||
|
||||
def collect_all(futures):
|
||||
r"""
|
||||
Collects the Futures into a single combined Future that is completed
|
||||
when all of the sub-futures are completed.
|
||||
|
||||
Arguments:
|
||||
futures: a list of Futures
|
||||
|
||||
Returns:
|
||||
Returns a Future object to a list of the passed in Futures.
|
||||
"""
|
||||
return torch._C._collect_all(futures)
|
||||
|
||||
def wait_all(futures):
|
||||
r"""
|
||||
Waits for all provided futures to be complete, and returns
|
||||
the list of completed values.
|
||||
|
||||
Arguments:
|
||||
futures: a list of Futures
|
||||
|
||||
Returns:
|
||||
A list of the completed Future results
|
||||
"""
|
||||
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
|
||||
|
@ -1045,7 +1045,8 @@ class DistAutogradTest(RpcAgentTestFixture):
|
||||
_set_rpc_done(None, 0)
|
||||
|
||||
# wait until all trainers are done
|
||||
torch.futures.wait_all(futures)
|
||||
for fut in futures:
|
||||
fut.wait()
|
||||
|
||||
@dist_init
|
||||
def test_trainer_ps(self):
|
||||
|
@ -822,7 +822,8 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
|
||||
futs.append(fut)
|
||||
|
||||
for fut in torch.futures.collect_all(futs).wait():
|
||||
for fut in futs:
|
||||
fut.wait()
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
|
||||
# Phase 2: Only worker2 has workload.
|
||||
@ -834,8 +835,9 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),))
|
||||
futs.append(fut)
|
||||
|
||||
for val in torch.futures.wait_all(futs):
|
||||
self.assertEqual(val, 0)
|
||||
for fut in futs:
|
||||
fut.wait()
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
|
||||
def test_wait_all_workers(self):
|
||||
rpc.init_rpc(
|
||||
@ -1240,9 +1242,9 @@ class RpcTest(RpcAgentTestFixture):
|
||||
futs.append(fut)
|
||||
|
||||
j = 0
|
||||
for val in torch.futures.wait_all(futs):
|
||||
for fut in futs:
|
||||
self.assertEqual(
|
||||
val, my_tensor_function(torch.ones(j, j), torch.ones(j, j))
|
||||
fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
|
||||
)
|
||||
j += 1
|
||||
|
||||
@ -1308,8 +1310,8 @@ class RpcTest(RpcAgentTestFixture):
|
||||
fut = rpc.rpc_async(worker_name(dst_rank), f, args=args)
|
||||
futs.append(fut)
|
||||
|
||||
for val in torch.futures.wait_all(futs):
|
||||
self.assertEqual(val, 0)
|
||||
for fut in futs:
|
||||
self.assertEqual(fut.wait(), 0)
|
||||
tok = time.time()
|
||||
print(
|
||||
"Rank {} finished testing {} times in {} seconds.".format(
|
||||
|
Reference in New Issue
Block a user