Fix RPC and ProcessGroup GIL deadlock (#45088)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45088

Fixes #45082

Found a few problems while working on #44983:

1. We deliberately swallow RPC timeouts during shutdown, as we haven't found a good way to handle those. When we converted `_wait_all_workers` into `_all_gather`, the same logic was inherited. However, `_all_gather` is meant to be used in more general scenarios, so it should no longer stay silent about errors. This commit lets the error propagate out of `_all_gather` and lets `shutdown()` catch and log it.

2. After fixing (1), I found that `UnpickledPythonCall` needs to acquire the GIL on destruction, which can deadlock when used in conjunction with `ProcessGroup`, because the `ProcessGroup` ctor is a synchronization point that holds the GIL. In `init_rpc`, followers (`rank != 0`) can exit before the leader (`rank == 0`). If the two overlap, the following can happen: a follower exits `init_rpc` after running `_broadcast_to_followers` and before reaching the dtor of `UnpickledPythonCall`. It then runs the `ProcessGroup` ctor, which holds the GIL and waits for the leader to join. The leader, however, is waiting for the response from `_broadcast_to_followers`, which is blocked by the dtor of `UnpickledPythonCall`. Hence the deadlock. This commit drops the GIL in the `ProcessGroup` ctor.

3. After fixing (2), I found that the `TensorPipe` backend nondeterministically fails `test_local_shutdown`, for a reason similar to (2), except that this time `shutdown()` on a follower runs before the leader finishes `init_rpc`. This commit adds a join to the `TensorPipe` backend's `init_rpc` after `_all_gather`.

The third fix should be able to solve the second as well, but since I saw no reason to hold the GIL during the `ProcessGroup` ctor, I made that change too.

Test Plan: Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D23825592

Pulled By: mrshenli

fbshipit-source-id: 94920f2ad357746a6b8e4ffaa380dd56a7310976
Committed by Facebook GitHub Bot
Parent: dfc88d4fd0
Commit: 09e7f62ce2
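To make the race in (2) concrete, here is a condensed, standalone sketch of the reproducer (it mirrors the new `test_deadlock` added below; the `slow_add` helper, worker names, ports, and the process-group rendezvous address are illustrative assumptions, not part of the patch):

# Sketch only: mirrors the reproducer from issue #45089 / the new test_deadlock.
# slow_add, the worker-name scheme, ports, and init_method are illustrative.
import os
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

def slow_add(x, y):
    return x + y

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 1:
        # Keep an RPC outstanding on the next worker while this rank races ahead.
        dst = f"worker{(rank + 1) % world_size}"
        fut = rpc.rpc_async(dst, slow_add, args=(torch.ones(2), torch.ones(2)), timeout=15)
        fut.wait()
    # Before this commit, the ProcessGroup constructor entered here held the GIL
    # while blocking on rendezvous; a peer that still needed the GIL to destroy an
    # UnpickledPythonCall could then never make progress, and both sides hung.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29501",  # separate rendezvous, illustrative
        rank=rank,
        world_size=world_size,
    )
    rpc.shutdown()

if __name__ == "__main__":
    mp.spawn(run, args=(3,), nprocs=3)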
In the c10d Python bindings, the ProcessGroupGloo, ProcessGroupNCCL, and ProcessGroupMPI constructors (and MPI's static create) are now bound with py::call_guard<py::gil_scoped_release>, so they no longer hold the GIL:

@@ -653,7 +653,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
               const std::shared_ptr<::c10d::Store>&,
               int,
               int,
-              ::c10d::ProcessGroupGloo::Options>())
+              ::c10d::ProcessGroupGloo::Options>(),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           py::init([](const std::shared_ptr<::c10d::Store>& store,
                       int rank,
@@ -684,7 +685,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
           py::arg("store"),
           py::arg("rank"),
           py::arg("size"),
-          py::arg("timeout") = std::chrono::milliseconds(10 * 1000)); // NOLINT
+          py::arg("timeout") = std::chrono::milliseconds(10 * 1000), // NOLINT
+          py::call_guard<py::gil_scoped_release>());
 #endif

 #ifdef USE_C10D_NCCL
@@ -694,7 +696,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
               const std::shared_ptr<::c10d::Store>&,
               int,
               int,
-              ::c10d::ProcessGroupNCCL::Options>())
+              ::c10d::ProcessGroupNCCL::Options>(),
+          py::call_guard<py::gil_scoped_release>())
       .def(
           py::init([](const std::shared_ptr<::c10d::Store>& store,
                       int rank,
@@ -710,7 +713,8 @@ They are used in specifying strategies for reduction collectives, e.g.,
           py::arg("rank"),
           py::arg("size"),
           py::arg("timeout") = std::chrono::milliseconds(
-              ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis));
+              ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis),
+          py::call_guard<py::gil_scoped_release>());

   py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options")
       .def(py::init<>())
@@ -725,9 +729,12 @@ They are used in specifying strategies for reduction collectives, e.g.,
   // Define static create function instead of a constructor, because
   // this function may return null. This happens if this process is not
   // part of a sub group that is to be created.
-  processGroupMPI.def_static("create", [](std::vector<int> ranks) {
-    return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
-  });
+  processGroupMPI.def_static(
+      "create",
+      [](std::vector<int> ranks) {
+        return ::c10d::ProcessGroupMPI::createProcessGroupMPI(ranks);
+      },
+      py::call_guard<py::gil_scoped_release>());
 #endif

   shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
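The user-visible effect of these binding changes is that a blocked ProcessGroup constructor no longer starves other Python threads in the same process. A rough two-process sketch of that behavior (the heartbeat thread, rendezvous file, and timings are illustrative, not from the patch):

# Sketch only: run once with RANK=0 and once with RANK=1 in two shells, and
# delete /tmp/pg_demo_rendezvous between runs. With the GIL released in the
# ProcessGroup constructor, the heartbeat thread keeps printing while rank 0
# waits inside the constructor for the deliberately late rank 1; before this
# change the constructor held the GIL, so the thread went silent for that wait.
import os
import threading
import time
import torch.distributed as dist

def heartbeat(stop):
    while not stop.is_set():
        print("python thread still making progress")
        time.sleep(0.5)

def main():
    rank = int(os.environ["RANK"])
    stop = threading.Event()
    t = threading.Thread(target=heartbeat, args=(stop,))
    t.start()
    if rank == 1:
        time.sleep(5)  # keep rank 0 blocked inside the Gloo constructor for a while
    dist.init_process_group(
        backend="gloo",
        init_method="file:///tmp/pg_demo_rendezvous",  # illustrative rendezvous file
        rank=rank,
        world_size=2,
    )
    stop.set()
    t.join()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()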
In the RPC layer, `_all_gather` now surfaces follower failures instead of only logging them, and `_wait_all_workers` catches and logs the resulting error so `shutdown()` can still proceed:

@@ -196,15 +196,20 @@ def _all_gather(obj):
             timeout=timeout
         )
         worker_name_to_response_future_dict[follower_name] = fut

+    errors = []
     for follower_name, fut in worker_name_to_response_future_dict.items():
         try:
             fut.wait()
         except RuntimeError as ex:
-            logger.error(
-                "{worker_name} failed to respond to 'Shutdown Proceed.' request in {timeout}".format(
-                    worker_name=follower_name, timeout=timeout
-                )
-            )
+            errors.append((follower_name, ex))
+
+    if errors:
+        raise RuntimeError(
+            f"Followers {[e[0] for e in errors]} timed out in _all_gather "
+            f"after {timeout} seconds. The first exception is {errors[0][1]}"
+        )

     return states.gathered_objects

@@ -217,7 +222,12 @@ def _wait_all_workers():
     terminate the RPC framework, and there is no guarantee that the RPC
     framework will work after this method returns.
     """
-    _all_gather(None)
+    try:
+        _all_gather(None)
+    except RuntimeError as ex:
+        logger.error(
+            f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
+        )


 @_require_initialized
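The shape of this change, in isolation: collect per-follower failures, raise once with a summary, and let the shutdown-time caller downgrade that error to a log message. A self-contained sketch of the same pattern using ordinary futures (all names here are illustrative):

# Sketch of the collect-then-raise pattern; concurrent.futures stands in for
# the RPC response futures, and every name is illustrative.
import logging
from concurrent.futures import Future
from typing import Dict

logger = logging.getLogger(__name__)

def wait_for_followers(futures: Dict[str, Future], timeout: float):
    errors = []
    for follower_name, fut in futures.items():
        try:
            fut.result(timeout=timeout)
        except Exception as ex:
            errors.append((follower_name, ex))
    if errors:
        # Raise instead of silently logging, so general-purpose callers see it.
        raise RuntimeError(
            f"Followers {[e[0] for e in errors]} timed out "
            f"after {timeout} seconds. The first exception is {errors[0][1]}"
        )

def wait_all_workers_like(futures: Dict[str, Future], timeout: float):
    # Shutdown-time caller: a late follower should not crash teardown,
    # so catch and log here instead.
    try:
        wait_for_followers(futures, timeout)
    except RuntimeError as ex:
        logger.error("Failed to respond to 'Shutdown Proceed' in time, got error %s", ex)

if __name__ == "__main__":
    ok, bad = Future(), Future()
    ok.set_result("ready")
    bad.set_exception(TimeoutError("no response"))
    wait_all_workers_like({"worker1": ok, "worker2": bad}, timeout=5)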
In the TensorPipe backend handler, `init_rpc` now joins the agent after the `_all_gather`-based setup:

@@ -254,6 +254,7 @@ def _tensorpipe_init_backend_handler(store, name, rank, world_size, rpc_backend_

     try:
         _tensorpipe_check_device_maps(agent, rpc_backend_options.device_maps)
+        agent.join()
     except Exception:
         api.shutdown()
         raise
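Roughly speaking, the extra `agent.join()` gives `init_rpc` barrier-like behavior for the TensorPipe backend: no worker returns, and can then call `shutdown(graceful=False)`, while the leader is still collecting responses. A thread-based sketch of that ordering fix (names and timings are illustrative, not the actual join implementation):

# Sketch only: the barrier plays the role the extra agent.join() plays at the
# end of init_rpc, keeping any worker from starting teardown while the leader
# is still waiting on responses from everyone.
import threading
import time

world_size = 3
barrier = threading.Barrier(world_size)

def init_then_shutdown(rank):
    # ... per-worker init_rpc-style setup; the leader (rank 0) takes longer
    # because it waits on responses from every follower ...
    time.sleep(1.0 if rank == 0 else 0.1)
    barrier.wait()  # without this, a follower could reach "shutdown" first
    print(f"worker{rank}: safe to run local shutdown now")

threads = [threading.Thread(target=init_then_shutdown, args=(r,)) for r in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()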
And a regression test reproducing the deadlock reported in #45089:

@@ -2406,6 +2406,24 @@ class RpcTest(RpcAgentTestFixture):
         # exit all workers non-gracefully.
         rpc.shutdown(graceful=False)

+    @dist_init
+    def test_deadlock(self):
+        # this test is copied from https://github.com/pytorch/pytorch/issues/45089
+        if self.rank == 1:
+            dst1 = worker_name((self.rank + 1) % self.world_size)
+            x = torch.ones(2)
+            y = torch.ones(2)
+            rpc.rpc_async(dst1, RpcTest._slow_add, args=(x, y), timeout=15).wait()
+
+        dist_initialized = dist.is_initialized()
+        if not dist_initialized:
+            dist.init_process_group(
+                backend="gloo",
+                init_method=self.init_method,
+                rank=self.rank,
+                world_size=self.world_size,
+            )
+
     @dist_init(setup_rpc=False)
     def test_local_shutdown_with_rpc(self):
         # test that we can start RPC, send RPCs, and then run local shutdown.