mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add optional timeout argument for RpcAgent join() (#76194)
Summary: This PR was created to resolve issue brought up in https://fb.workplace.com/groups/319878845696681/permalink/741428653541696/ Changes: - Adds timeout argument to RpcAgent.join() - Add optional timeout argument to ThriftRpcAgent barrier() - During shutdown (ThriftRpcAgent join) calls the barrier, the agent will use the timeout passed to shutdown and pass that timeout into the join(). - Update API.py to also include fix bug (missing timeout for signal) - Change default shutdown timeout to 0 (no timeout). Existing functionality in _all_gather will remain the same and wait indefinitely for signal if no timeout is set for the function. New functionality has user specify timeout for both the signal and rpc calls. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76194 Test Plan: Modified barrier test buck test torch/fb/distributed/thriftRpcBackend/test:ThriftRpcAgentTest -- BarrierTest Reviewed By: mrshenli Differential Revision: D35825382 fbshipit-source-id: e91e9ab5d9fca08787cb6b6b8125a4b03d1c7cde (cherry picked from commit fcf899a387001574bf4e39a213ea741611d76097)
This commit is contained in:
committed by
PyTorch MergeBot
parent
b34739fbef
commit
e68686bb05
@ -33,7 +33,7 @@ class WorkerInfo:
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
class RpcAgent:
|
||||
def join(self, shutdown: bool = False): ...
|
||||
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
||||
def sync(self): ...
|
||||
def shutdown(self): ...
|
||||
@overload
|
||||
@ -92,7 +92,7 @@ class TensorPipeAgent(RpcAgent):
|
||||
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
|
||||
devices: List[torch.device],
|
||||
): ...
|
||||
def join(self, shutdown: bool = False): ...
|
||||
def join(self, shutdown: bool = False, timeout: float = 0): ...
|
||||
def shutdown(self): ...
|
||||
@overload
|
||||
def get_worker_info(self) -> WorkerInfo: ...
|
||||
|
@ -137,7 +137,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
|
||||
"join",
|
||||
&RpcAgent::join,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
py::arg("shutdown") = false)
|
||||
py::arg("shutdown") = false,
|
||||
py::arg("timeout") = 0)
|
||||
.def(
|
||||
"sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
@ -603,7 +604,8 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
|
||||
"join",
|
||||
&TensorPipeAgent::join,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
py::arg("shutdown") = false)
|
||||
py::arg("shutdown") = false,
|
||||
py::arg("timeout") = 0)
|
||||
.def(
|
||||
"shutdown",
|
||||
&TensorPipeAgent::shutdown,
|
||||
|
@ -213,7 +213,7 @@ class TORCH_API RpcAgent {
|
||||
|
||||
// Call sync and join all internal threads. This method should be called
|
||||
// before every RPC process exits.
|
||||
virtual void join(bool shutdown = false) = 0;
|
||||
virtual void join(bool shutdown = false, float timeout = 0) = 0;
|
||||
|
||||
// Synchronize the this process with other ``RpcAgent`` processes. Block until
|
||||
// all ``RpcAgent``s reach this method and send all pending messages.
|
||||
|
@ -1078,7 +1078,7 @@ void TensorPipeAgent::leaveGroup() {
|
||||
}
|
||||
|
||||
// TODO: Remove join()
|
||||
void TensorPipeAgent::join(bool shutdown) {
|
||||
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
|
||||
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
|
||||
if (!isStaticGroup_) {
|
||||
leaveGroup();
|
||||
|
@ -182,7 +182,7 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
|
||||
|
||||
// join() and sync() would be deprecated -
|
||||
// https://github.com/pytorch/pytorch/issues/27647
|
||||
void join(bool shutdown = false) override;
|
||||
void join(bool shutdown = false, float timeout = 0) override;
|
||||
void sync() override{};
|
||||
void startImpl() override;
|
||||
void shutdownImpl() override;
|
||||
|
@ -98,7 +98,8 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
|
||||
"join",
|
||||
&TensorPipeAgent::join,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
py::arg("shutdown") = false)
|
||||
py::arg("shutdown") = false,
|
||||
py::arg("timeout") = 0)
|
||||
.def(
|
||||
"shutdown",
|
||||
&TensorPipeAgent::shutdown,
|
||||
|
@ -200,8 +200,20 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
sequence_id = concat_names + str(sequence_num)
|
||||
|
||||
is_leader = leader_name == self_name
|
||||
|
||||
if timeout == UNSET_RPC_TIMEOUT:
|
||||
timeout = get_rpc_timeout()
|
||||
# Timeout is specified by agent for RPC calls
|
||||
rpc_timeout = get_rpc_timeout()
|
||||
# No timeout for signal
|
||||
signal_timeout = None
|
||||
elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
|
||||
# No timeout for RPC
|
||||
rpc_timeout = timeout
|
||||
# No timeout for signal
|
||||
signal_timeout = None
|
||||
else:
|
||||
# Signal and RPC timeout use the same timeout
|
||||
signal_timeout = rpc_timeout = timeout
|
||||
|
||||
# Phase 1: Followers send it's object to the leader
|
||||
if is_leader:
|
||||
@ -211,12 +223,14 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
leader_name,
|
||||
_gather_to_leader,
|
||||
args=(sequence_id, self_name, obj, worker_names),
|
||||
timeout=timeout,
|
||||
timeout=rpc_timeout,
|
||||
)
|
||||
|
||||
with _all_gather_dict_lock:
|
||||
states = _all_gather_sequence_id_to_states[sequence_id]
|
||||
states.proceed_signal.wait()
|
||||
|
||||
# Timeout is either set by function parameter or None (which is indefinite)
|
||||
states.proceed_signal.wait(timeout=signal_timeout)
|
||||
|
||||
# Phase 2: Leader broadcast gathered results to all followers
|
||||
# Leader's signal is the first to be unblocked, after receiving all
|
||||
@ -228,7 +242,7 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
follower_name,
|
||||
_broadcast_to_followers,
|
||||
args=(sequence_id, states.gathered_objects),
|
||||
timeout=timeout
|
||||
timeout=rpc_timeout
|
||||
)
|
||||
worker_name_to_response_future_dict[follower_name] = fut
|
||||
|
||||
@ -242,7 +256,7 @@ def _all_gather(obj, worker_names=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
if errors:
|
||||
raise RuntimeError(
|
||||
f"Followers {[e[0] for e in errors]} timed out in _all_gather "
|
||||
f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
|
||||
f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
|
||||
)
|
||||
|
||||
# Clean up for the states using the sequence_id
|
||||
@ -343,7 +357,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
|
||||
if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
|
||||
_wait_all_workers(timeout)
|
||||
_delete_all_user_and_unforked_owner_rrefs()
|
||||
agent.join(shutdown=True)
|
||||
agent.join(shutdown=True, timeout=timeout)
|
||||
else:
|
||||
# This is a dynamic group so we need to grab the token for the operation
|
||||
my_worker_info = agent.get_worker_info()
|
||||
@ -353,7 +367,7 @@ def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
|
||||
for worker in all_worker_infos:
|
||||
if worker.name != my_name:
|
||||
rpc_sync(worker.name, _update_group_membership, args=(my_worker_info, [], {}, False))
|
||||
agent.join(shutdown=True)
|
||||
agent.join(shutdown=True, timeout=timeout)
|
||||
finally:
|
||||
# In case of errors, continue to complete the local shutdown.
|
||||
_finalize_shutdown()
|
||||
|
@ -11,7 +11,7 @@ from torch._C._distributed_rpc import (
|
||||
# For any RpcAgent.
|
||||
DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
|
||||
DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
|
||||
DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
|
||||
DEFAULT_SHUTDOWN_TIMEOUT: float = 0
|
||||
|
||||
# For TensorPipeAgent.
|
||||
DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS
|
||||
|
Reference in New Issue
Block a user