Fix RPC and ProcessGroup GIL deadlock (#45088)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45088

Fixes #45082

Found a few problems while working on #44983:

1. We deliberately swallow RPC timeouts during shutdown, as we haven't
   found a good way to handle them. When `_wait_all_workers` was
   converted into `_all_gather`, the same logic was inherited. However,
   as `_all_gather` is meant to be used in more general scenarios, it
   should no longer stay silent about errors. This commit lets errors
   propagate out of `_all_gather` and has `shutdown()` catch and log
   them. (A standalone sketch of this error flow follows this list.)
2. After fixing (1), I found that `UnpickledPythonCall` needs to
   acquire the GIL on destruction, and this can lead to a deadlock when
   used in conjunction with `ProcessGroup`, because the `ProcessGroup`
   ctor is a synchronization point that holds the GIL. In `init_rpc`,
   followers (`rank != 0`) can exit before the leader (`rank == 0`). If
   the two happen together, we can deadlock as follows: a follower
   exits `init_rpc` after running `_broadcast_to_followers` but before
   reaching the dtor of `UnpickledPythonCall`. It then enters the
   `ProcessGroup` ctor, which holds the GIL while waiting for the
   leader to join. The leader, however, is still waiting for the
   response to `_broadcast_to_followers`, which is blocked behind the
   dtor of `UnpickledPythonCall`. Hence the deadlock. This commit drops
   the GIL in the `ProcessGroup` ctor. (A threading analogy of this
   interleaving is sketched further below.)
3. After fixing (2), I found that the `TensorPipe` backend
   nondeterministically fails `test_local_shutdown`, for a reason
   similar to (2), except that this time `shutdown()` on a follower
   runs before the leader finishes `init_rpc`. This commit adds a join
   to the `TensorPipe` backend's `init_rpc` after `_all_gather`.
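
As a rough, self-contained sketch of the error flow described in (1):
this is illustrative only, and none of the names below exist in
`torch.distributed.rpc`.

```python
# Illustrative sketch only: `wait_fns` stands in for the per-follower
# response futures; these helpers are not torch.distributed.rpc code.
import logging

logger = logging.getLogger(__name__)


def wait_for_followers(wait_fns, timeout):
    """Collect follower failures and surface them to the caller."""
    errors = []
    for follower_name, wait in wait_fns.items():
        try:
            wait(timeout)
        except RuntimeError as ex:
            errors.append((follower_name, ex))
    if errors:
        # A general-purpose barrier should raise rather than stay silent.
        raise RuntimeError(
            f"Followers {[name for name, _ in errors]} timed out after "
            f"{timeout} seconds. First exception: {errors[0][1]}"
        )


def shutdown_barrier(wait_fns, timeout=5):
    # Only the shutdown path downgrades the failure to a log message,
    # since a terminating process cannot do much more about it.
    try:
        wait_for_followers(wait_fns, timeout)
    except RuntimeError as ex:
        logger.error("Failed to confirm 'Shutdown Proceed' in time: %s", ex)


if __name__ == "__main__":
    def ok(timeout):
        pass

    def timed_out(timeout):
        raise RuntimeError("RPC timed out")

    # One follower answers, one times out: shutdown logs instead of raising.
    shutdown_barrier({"worker1": ok, "worker2": timed_out})
```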

Fix (3) should resolve (2) as well, but since I didn't see a reason to
hold the GIL during the `ProcessGroup` ctor, I made that change too.
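
To make the deadlock in (2) concrete, here is a small, self-contained
threading analogy (no PyTorch involved; every name below is made up):
a plain lock stands in for the GIL, a barrier stands in for the
blocking `ProcessGroup` rendezvous, and the leader acquiring that lock
stands in for the `UnpickledPythonCall` dtor it must run before it can
make progress. Holding the lock across the barrier wait reproduces the
hang; releasing it first, which is what `py::gil_scoped_release` does
for the real ctor, lets both sides finish.

```python
# Threading analogy only; no PyTorch here. `gil` ~ the GIL,
# `rendezvous` ~ the blocking ProcessGroup ctor, and the leader taking
# `gil` ~ the UnpickledPythonCall dtor it must run before moving on.
import threading


def follower(gil, follower_ready, rendezvous, hold_lock_while_waiting):
    gil.acquire()                # follower still owns the "GIL" after its RPC work
    follower_ready.set()         # the leader's request to this follower has run
    if hold_lock_while_waiting:
        rendezvous.wait()        # bug: waits for the leader while holding the lock
        gil.release()
    else:
        gil.release()            # fix: release before the blocking rendezvous
        rendezvous.wait()


def leader(gil, follower_ready, rendezvous):
    follower_ready.wait()        # leader is blocked on the follower's reply...
    with gil:                    # ...which first needs the "GIL" (the dtor)
        pass
    rendezvous.wait()            # only then can the leader reach the rendezvous


def run(hold_lock_while_waiting):
    gil, ready, barrier = threading.Lock(), threading.Event(), threading.Barrier(2)
    threads = [
        threading.Thread(target=follower, args=(gil, ready, barrier, hold_lock_while_waiting), daemon=True),
        threading.Thread(target=leader, args=(gil, ready, barrier), daemon=True),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=2)
    return all(not t.is_alive() for t in threads)


if __name__ == "__main__":
    print("hold lock across the wait ->", run(True))   # False: both sides stuck
    print("release lock before wait  ->", run(False))  # True: rendezvous completes
```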

Test Plan: Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D23825592

Pulled By: mrshenli

fbshipit-source-id: 94920f2ad357746a6b8e4ffaa380dd56a7310976

@@ -653,7 +653,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
const std::shared_ptr<::c10d::Store>&,
int,
int,
::c10d::ProcessGroupGloo::Options>())
::c10d::ProcessGroupGloo::Options>(),
py::call_guard<py::gil_scoped_release>())
.def(
py::init([](const std::shared_ptr<::c10d::Store>& store,
int rank,
@@ -684,7 +685,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
py::arg("store"),
py::arg("rank"),
py::arg("size"),
py::arg("timeout") = std::chrono::milliseconds(10 * 1000)); // NOLINT
py::arg("timeout") = std::chrono::milliseconds(10 * 1000), // NOLINT
py::call_guard<py::gil_scoped_release>());
#endif
#ifdef USE_C10D_NCCL
@@ -694,7 +696,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
const std::shared_ptr<::c10d::Store>&,
int,
int,
::c10d::ProcessGroupNCCL::Options>())
::c10d::ProcessGroupNCCL::Options>(),
py::call_guard<py::gil_scoped_release>())
.def(
py::init([](const std::shared_ptr<::c10d::Store>& store,
int rank,
@@ -710,7 +713,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
py::arg("rank"),
py::arg("size"),
py::arg("timeout") = std::chrono::milliseconds(
::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis));
::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis),
py::call_guard<py::gil_scoped_release>());
py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options")
.def(py::init<>())
@@ -725,9 +729,12 @@ They are used in specifying strategies for reduction collectives, e.g.,
// Define static create function instead of a constructor, because
// this function may return null. This happens if this process is not
// part of a sub group that is to be created.
processGroupMPI.def_static("create", [](std::vector<int> ranks) {
return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
});
processGroupMPI.def_static(
"create",
[](std::vector<int> ranks) {
return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
},
py::call_guard<py::gil_scoped_release>());
#endif
shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")

@@ -196,15 +196,20 @@ def _all_gather(obj):
timeout=timeout
)
worker_name_to_response_future_dict[follower_name] = fut
errors = []
for follower_name, fut in worker_name_to_response_future_dict.items():
try:
fut.wait()
except RuntimeError as ex:
logger.error(
"{worker_name} failed to respond to 'Shutdown Proceed.' request in {timeout}".format(
worker_name=follower_name, timeout=timeout
)
)
errors.append((follower_name, ex))
if errors:
raise RuntimeError(
f"Followers {[e[0] for e in errors]} timed out in _all_gather "
f"after {timeout} seconds. The first exception is {errors[0][1]}"
)
return states.gathered_objects
@@ -217,7 +222,12 @@ def _wait_all_workers():
terminate the RPC framework, and there is no guarantee that the RPC
framework will work after this method returns.
"""
_all_gather(None)
try:
_all_gather(None)
except RuntimeError as ex:
logger.error(
f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
)
@_require_initialized

@@ -254,6 +254,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_
try:
_tensorpipe_check_device_maps(agent, rpc_backend_options.device_maps)
agent.join()
except Exception:
api.shutdown()
raise

@@ -2406,6 +2406,24 @@ class RpcTest(RpcAgentTestFixture):
# exit all workers non-gracefully.
rpc.shutdown(graceful=False)
@dist_init
def test_deadlock(self):
# this test is copied from https://github.com/pytorch/pytorch/issues/45089
if self.rank == 1:
dst1 = worker_name((self.rank + 1) % self.world_size)
x = torch.ones(2)
y = torch.ones(2)
rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
dist_initialized = dist.is_initialized()
if not dist_initialized:
dist.init_process_group(
backend="gloo",
init_method=self.init_method,
rank=self.rank,
world_size=self.world_size,
)
@dist_init(setup_rpc=False)
def test_local_shutdown_with_rpc(self):
# test that we can start RPC, send RPCs, and then run local shutdown.