Add optional timeout argument for RpcAgent join() (#76194)

Summary:
This PR was created to resolve an issue raised in https://fb.workplace.com/groups/319878845696681/permalink/741428653541696/

Changes:
- Add a timeout argument to RpcAgent.join()
- Add an optional timeout argument to ThriftRpcAgent barrier()
- During shutdown, ThriftRpcAgent's join() calls the barrier; the agent takes the timeout passed to shutdown() and forwards it into join().
- Update api.py, which also fixes a bug (a missing timeout on the signal wait)
- Change the default shutdown timeout to 0 (no timeout). Existing _all_gather behavior is unchanged: it waits indefinitely for the signal when no timeout is set. With a user-specified timeout, the timeout applies to both the signal wait and the RPC calls (see the sketch below).
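
A minimal usage sketch of the new shutdown timeout (single-worker group for brevity; the address and port values are illustrative):

import os

import torch.distributed.rpc as rpc

# Illustrative single-worker setup; the default init method reads these
# environment variables for rendezvous.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

rpc.init_rpc("worker0", rank=0, world_size=1)

# timeout=0 (the new default) preserves the old behavior and waits
# indefinitely; a positive timeout now bounds both the proceed-signal wait
# and the internal RPC calls, and is forwarded to agent.join().
rpc.shutdown(timeout=30.0)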

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76194

Test Plan:
Modified the barrier test:

buck test torch/fb/distributed/thriftRpcBackend/test:ThriftRpcAgentTest -- BarrierTest

Reviewed By: mrshenli

Differential Revision: D35825382

fbshipit-source-id: e91e9ab5d9fca08787cb6b6b8125a4b03d1c7cde
(cherry picked from commit fcf899a387001574bf4e39a213ea741611d76097)
Author: Howard Huang
Date: 2022-05-02 18:05:19 -07:00
Committed by: PyTorch MergeBot
Parent: b34739fbef
Commit: e68686bb05
8 changed files with 33 additions and 16 deletions

View File: torch/_C/_distributed_rpc.pyi

@@ -33,7 +33,7 @@ class WorkerInfo:
     def __repr__(self) -> str: ...
 
 class RpcAgent:
-    def join(self, shutdown: bool = False): ...
+    def join(self, shutdown: bool = False, timeout: float = 0): ...
     def sync(self): ...
     def shutdown(self): ...
     @overload
@@ -92,7 +92,7 @@ class TensorPipeAgent(RpcAgent):
         reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
         devices: List[torch.device],
     ): ...
-    def join(self, shutdown: bool = False): ...
+    def join(self, shutdown: bool = False, timeout: float = 0): ...
     def shutdown(self): ...
     @overload
     def get_worker_info(self) -> WorkerInfo: ...

View File: torch/csrc/distributed/rpc/init.cpp

@@ -137,7 +137,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           "join",
           &RpcAgent::join,
           py::call_guard<py::gil_scoped_release>(),
-          py::arg("shutdown") = false)
+          py::arg("shutdown") = false,
+          py::arg("timeout") = 0)
       .def(
           "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
       .def(
@@ -603,7 +604,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           "join",
           &TensorPipeAgent::join,
           py::call_guard<py::gil_scoped_release>(),
-          py::arg("shutdown") = false)
+          py::arg("shutdown") = false,
+          py::arg("timeout") = 0)
       .def(
           "shutdown",
           &TensorPipeAgent::shutdown,

View File: torch/csrc/distributed/rpc/rpc_agent.h

@@ -213,7 +213,7 @@ class TORCH_API RpcAgent {
   // Call sync and join all internal threads. This method should be called
   // before every RPC process exits.
-  virtual void join(bool shutdown = false) = 0;
+  virtual void join(bool shutdown = false, float timeout = 0) = 0;
 
   // Synchronize the this process with other ``RpcAgent`` processes. Block until
   // all ``RpcAgent``s reach this method and send all pending messages.

View File: torch/csrc/distributed/rpc/tensorpipe_agent.cpp

@@ -1078,7 +1078,7 @@ void TensorPipeAgent::leaveGroup() {
 }
 
 // TODO: Remove join()
-void TensorPipeAgent::join(bool shutdown) {
+void TensorPipeAgent::join(bool shutdown, float /* unused */) {
   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
   if (!isStaticGroup_) {
     leaveGroup();
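
Note: TensorPipeAgent accepts the new parameter only to keep the interface uniform and ignores it (hence the anonymous float /* unused */); per the summary above, the timeout is consumed by the internal ThriftRpcAgent, which forwards it to its barrier.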

View File: torch/csrc/distributed/rpc/tensorpipe_agent.h

@@ -182,7 +182,7 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
   // join() and sync() would be deprecated -
   // https://github.com/pytorch/pytorch/issues/27647
-  void join(bool shutdown = false) override;
+  void join(bool shutdown = false, float timeout = 0) override;
   void sync() override{};
   void startImpl() override;
   void shutdownImpl() override;

View File: torch/csrc/distributed/rpc/testing/init.cpp

@@ -98,7 +98,8 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
           "join",
           &TensorPipeAgent::join,
           py::call_guard<py::gil_scoped_release>(),
-          py::arg("shutdown") = false)
+          py::arg("shutdown") = false,
+          py::arg("timeout") = 0)
       .def(
           "shutdown",
           &TensorPipeAgent::shutdown,

View File: torch/distributed/rpc/api.py

@@ -200,8 +200,20 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
     sequence_id = concat_names + str(sequence_num)
 
     is_leader = leader_name == self_name
 
     if timeout == UNSET_RPC_TIMEOUT:
-        timeout = get_rpc_timeout()
+        # Timeout is specified by agent for RPC calls
+        rpc_timeout = get_rpc_timeout()
+        # No timeout for signal
+        signal_timeout = None
+    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
+        # No timeout for RPC
+        rpc_timeout = timeout
+        # No timeout for signal
+        signal_timeout = None
+    else:
+        # Signal and RPC timeout use the same timeout
+        signal_timeout = rpc_timeout = timeout
 
     # Phase 1: Followers send it's object to the leader
     if is_leader:
@@ -211,12 +223,14 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
             leader_name,
             _gather_to_leader,
             args=(sequence_id, self_name, obj, worker_names),
-            timeout=timeout,
+            timeout=rpc_timeout,
         )
 
     with _all_gather_dict_lock:
         states = _all_gather_sequence_id_to_states[sequence_id]
-    states.proceed_signal.wait()
+
+    # Timeout is either set by function parameter or None (which is indefinite)
+    states.proceed_signal.wait(timeout=signal_timeout)
 
     # Phase 2: Leader broadcast gathered results to all followers
     # Leader's signal is the first to be unblocked, after receiving all
@@ -228,7 +242,7 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
                 follower_name,
                 _broadcast_to_followers,
                 args=(sequence_id, states.gathered_objects),
-                timeout=timeout
+                timeout=rpc_timeout
             )
             worker_name_to_response_future_dict[follower_name] = fut
@@ -242,7 +256,7 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
         if errors:
             raise RuntimeError(
                 f"Followers {[e[0] for e in errors]} timed out in _all_gather "
-                f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
+                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
             )
 
     # Clean up for the states using the sequence_id
@@ -343,7 +357,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
         if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
             _wait_all_workers(timeout)
             _delete_all_user_and_unforked_owner_rrefs()
-            agent.join(shutdown=True)
+            agent.join(shutdown=True, timeout=timeout)
         else:
             # This is a dynamic group so we need to grab the token for the operation
             my_worker_info = agent.get_worker_info()
@@ -353,7 +367,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
             for worker in all_worker_infos:
                 if worker.name != my_name:
                     rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
-            agent.join(shutdown=True)
+            agent.join(shutdown=True, timeout=timeout)
         finally:
             # In case of errors, continue to complete the local shutdown.
             _finalize_shutdown()

View File: torch/distributed/rpc/constants.py

@@ -11,7 +11,7 @@ from torch._C._distributed_rpc import (
 # For any RpcAgent.
 DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
 DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
-DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
+DEFAULT_SHUTDOWN_TIMEOUT: float = 0
 
 # For TensorPipeAgent.
 DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
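
Worth noting: with the default lowered from 5.0 to 0, a plain rpc.shutdown() no longer applies the old 5-second bound to its internal RPCs and instead waits indefinitely; callers who want a bound must now pass one explicitly. A quick check against the new constant (import path as in this diff):

from torch.distributed.rpc.constants import DEFAULT_SHUTDOWN_TIMEOUT

# 0 is the "no timeout" sentinel introduced by this PR; _all_gather maps it
# to an indefinite signal wait and unbounded RPCs.
assert DEFAULT_SHUTDOWN_TIMEOUT == 0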