Let rpc._all_gather use default RPC timeout (#44983)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44983

`_all_gather` was converted from `_wait_all_workers` and inherited its
fixed 5-second timeout. As `_all_gather` is meant to support a broader
set of use cases, the timeout configuration should be more flexible.
This PR makes `rpc._all_gather` use the global default RPC timeout.

Test Plan: Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D23794383

Pulled By: mrshenli

fbshipit-source-id: 382f52c375f0f25c032c5abfc910f72baf4c5ad9
This commit is contained in:
Shen Li
2020-09-23 08:04:27 -07:00
committed by Facebook GitHub Bot
parent e5bade7b2c
commit 94c3cdd994
8 changed files with 42 additions and 27 deletions

View File

@ -27,7 +27,6 @@ namespace rpc {
namespace {
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);
constexpr float kSecToMsConversion = 1000;
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

View File

@ -8,12 +8,6 @@
namespace torch {
namespace distributed {
namespace rpc {
const std::string kRPCTimeoutErrorStr =
"RPC ran for more than {} milliseconds and timed out.";
namespace {
constexpr auto kSecToMsConversion = 1000;
}
////////////////////////// MessageCounter /////////////////////////////////
@ -802,7 +796,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
for (const auto& timedOutFuture : timedOutFutures) {
auto errStr =
fmt::format(kRPCTimeoutErrorStr, timedOutFuture.timeout_.count());
fmt::format(kRpcTimeoutErrorStr, timedOutFuture.timeout_.count());
auto err = makeRPCError(errStr, RPCErrorType::TIMEOUT);
if (!timedOutFuture.future_->hasError()) {

View File

@ -17,6 +17,9 @@ constexpr float kDefaultRpcTimeoutSeconds = 60;
// timeout for RPCs.
constexpr float kUnsetRpcTimeout = -1;
constexpr auto kDefaultInitMethod = "env://";
constexpr float kSecToMsConversion = 1000;
constexpr auto kRpcTimeoutErrorStr =
"RPC ran for more than set timeout ({} ms) and will now be marked with an error";
using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;

View File

@ -22,16 +22,12 @@ namespace {
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
const std::string kDefaultUvAddress = "127.0.0.1";
constexpr long kToMilliseconds = 1000;
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
const std::string kRpcTimeoutErrorStr =
"RPC ran for more than set timeout ({} ms) and will now be marked with an error";
inline void checkCPUTensor(const torch::Tensor& tensor) {
TORCH_CHECK(
@ -273,7 +269,7 @@ TensorPipeAgent::TensorPipeAgent(
WorkerInfo(std::move(selfName), selfId),
std::move(cb),
std::chrono::milliseconds(
(long)(opts.rpcTimeoutSeconds * kToMilliseconds))),
(long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
opts_(std::move(opts)),
threadPool_(opts_.numWorkerThreads),
context_(std::make_shared<tensorpipe::Context>(
@ -685,7 +681,7 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
? getRpcTimeout()
: std::chrono::milliseconds(
static_cast<int>(rpcTimeoutSeconds * kToMilliseconds));
static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
// We only add to the timeoutMap_ if the timeout is not 0. Per our
// documentation, a user-provided timeout of 0 indicates the RPC should never

View File

@ -6,10 +6,6 @@ namespace torch {
namespace distributed {
namespace rpc {
namespace {
constexpr auto kSecToMsConversion = 1000;
}
std::string fromVec(const std::vector<char>& vec) {
return std::string(vec.begin(), vec.end());
}

View File

@ -12,6 +12,7 @@ from . import (
PyRRef,
RemoteProfilerManager,
WorkerInfo,
get_rpc_timeout,
_cleanup_python_rpc_handler,
_delete_all_user_and_unforked_owner_rrefs,
_destroy_rref_context,
@ -34,7 +35,7 @@ from .internal import (
_build_rpc_profiling_key,
)
from .constants import UNSET_RPC_TIMEOUT
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
logger = logging.getLogger(__name__)
@ -142,7 +143,7 @@ def _broadcast_to_followers(sequence_id, objects_map):
@_require_initialized
def _all_gather(obj):
def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
r"""
This is similar to torch.distributed.all_gather(), but is using RPC. It
picks the worker with the smallest name (alphabetic order) as the leader.
@ -163,8 +164,8 @@ def _all_gather(obj):
_all_gather_sequence_id += 1
is_leader = leader_name == self_name
# Set a long enough timeout for all shutdown messages to be processed.
timeout = 5 # second
if timeout == UNSET_RPC_TIMEOUT:
timeout = get_rpc_timeout()
# Phase 1: Followers send it's object to the leader
if is_leader:
@ -178,9 +179,7 @@ def _all_gather(obj):
)
with _all_gather_dict_lock:
states = _all_gather_sequence_id_to_states[
sequence_id
]
states = _all_gather_sequence_id_to_states[sequence_id]
states.proceed_signal.wait()
# Phase 2: Leader broadcast gathered results to all followers
@ -207,7 +206,7 @@ def _all_gather(obj):
if errors:
raise RuntimeError(
f"Followers {[e[0] for e in errors]} timed out in _all_gather "
f"after {timeout} seconds. The first exception is {errors[0][1]}"
f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
)
return states.gathered_objects
@ -223,7 +222,7 @@ def _wait_all_workers():
framework will work after this method returns.
"""
try:
_all_gather(None)
_all_gather(None, timeout=DEFAULT_SHUTDOWN_TIMEOUT)
except RuntimeError as ex:
logger.error(
f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"

View File

@ -12,6 +12,7 @@ from . import (
# For any RpcAgent.
DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC
DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
DEFAULT_SHUTDOWN_TIMEOUT = 5.0
# For ProcessGroupAgent.
DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS

View File

@ -131,6 +131,19 @@ class MyPickleClass:
self.t = val
class SlowPickleClass:
    """A picklable object whose (de)serialization sleeps for ``t`` seconds.

    Test helper: forces RPC argument serialization/deserialization to take
    long enough to exceed a configured RPC timeout.
    """

    def __init__(self, t):
        # Sleep duration (seconds) applied during both pickling and unpickling.
        self.t = t

    def __getstate__(self):
        # Delay serialization, then expose only the duration as state.
        time.sleep(self.t)
        return (self.t, )

    def __setstate__(self, obj):
        # Restore the duration from the 1-tuple state, then delay
        # deserialization by the same amount.
        (self.t, ) = obj
        time.sleep(self.t)
class MyClass:
def __init__(self, a):
self.a = a
@ -931,6 +944,20 @@ class RpcTest(RpcAgentTestFixture):
self.assertEqual(expected, results)
@dist_init
def test_all_gather_timeout(self):
# Shrink the per-RPC timeout to 0.1s so the 0.5s sleep inside
# SlowPickleClass (de)serialization reliably exceeds it.
rpc._set_rpc_timeout(0.1)
if self.rank == 0:
# _all_gather elects the worker with the smallest name as leader;
# presumably rank 0's worker name sorts first, so it reports the
# aggregate "Followers ... timed out" error with the 0.10s timeout
# formatted to two decimals.
with self.assertRaisesRegex(
RuntimeError,
"timed out in _all_gather after 0\\.10 seconds"
):
rpc.api._all_gather(SlowPickleClass(0.5))
else:
# Non-leader ranks fail on the underlying RPC itself, which reports
# the timeout in milliseconds (100 ms).
with self.assertRaisesRegex(RuntimeError, "timeout.*100 ms"):
rpc.api._all_gather(SlowPickleClass(0.5))
@dist_init
def test_graceful_shutdown_with_uneven_workload(self):
"""Test graceful termination."""