Add optional timeout argument for RpcAgent join() (#76194)

Summary:
This PR was created to resolve issue brought up in https://fb.workplace.com/groups/319878845696681/permalink/741428653541696/

Changes:
- Adds timeout argument to RpcAgent.join()
- Add optional timeout argument to ThriftRpcAgent barrier()
- During shutdown (ThriftRpcAgent join) calls the barrier, the agent will use the timeout passed to shutdown and pass that timeout into the join().
- Update API.py to also include fix bug (missing timeout for signal)
- Change default shutdown timeout to 0 (no timeout). Existing functionality in _all_gather will remain the same and wait indefinitely for signal if no timeout is set for the function. New functionality has user specify timeout for both the signal and rpc calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76194

Test Plan:
Modified barrier test

buck test torch/fb/distributed/thriftRpcBackend/test:ThriftRpcAgentTest -- BarrierTest

Reviewed By: mrshenli

Differential Revision: D35825382

fbshipit-source-id: e91e9ab5d9fca08787cb6b6b8125a4b03d1c7cde
(cherry picked from commit fcf899a387001574bf4e39a213ea741611d76097)
This commit is contained in:
Howard Huang
2022-05-02 18:05:19 -07:00
committed by PyTorch MergeBot
parent b34739fbef
commit e68686bb05
8 changed files with 33 additions and 16 deletions

View File

@ -11,7 +11,7 @@ from torch._C._distributed_rpc import (
# For any RpcAgent.
DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC
DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD
DEFAULT_SHUTDOWN_TIMEOUT: float = 5.0
DEFAULT_SHUTDOWN_TIMEOUT: float = 0
# For TensorPipeAgent.
DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS