mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Let rpc._all_gather use default RPC timeout (#44983)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44983

`_all_gather` was converted from `_wait_all_workers` and inherited its fixed 5-second timeout. Since `_all_gather` is meant to support a broader set of use cases, its timeout configuration should be more flexible. This PR makes `rpc._all_gather` use the global default RPC timeout.

Test Plan: Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D23794383

Pulled By: mrshenli

fbshipit-source-id: 382f52c375f0f25c032c5abfc910f72baf4c5ad9
committed by Facebook GitHub Bot
parent e5bade7b2c
commit 94c3cdd994
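For illustration, a minimal sketch of how the new keyword argument behaves from the Python API; the gathered object and the surrounding RPC setup are hypothetical, and `_all_gather` remains an internal helper:

    import torch.distributed.rpc as rpc
    from torch.distributed.rpc.constants import DEFAULT_SHUTDOWN_TIMEOUT

    # With no explicit timeout, _all_gather now falls back to the global default
    # RPC timeout (configurable via rpc._set_rpc_timeout) instead of a
    # hard-coded 5 seconds.
    gathered = rpc.api._all_gather({"rank": rpc.get_worker_info().id})

    # A per-call timeout (in seconds) can still be passed, which is what
    # _wait_all_workers now does during shutdown.
    gathered = rpc.api._all_gather(None, timeout=DEFAULT_SHUTDOWN_TIMEOUT)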
@@ -27,7 +27,6 @@ namespace rpc {
 namespace {
 
 constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);
-constexpr float kSecToMsConversion = 1000;
 
 template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
@@ -8,12 +8,6 @@
 namespace torch {
 namespace distributed {
 namespace rpc {
-const std::string kRPCTimeoutErrorStr =
-    "RPC ran for more than {} milliseconds and timed out.";
-
-namespace {
-constexpr auto kSecToMsConversion = 1000;
-}
 
 ////////////////////////// MessageCounter /////////////////////////////////
 
@@ -802,7 +796,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
 
   for (const auto& timedOutFuture : timedOutFutures) {
     auto errStr =
-        fmt::format(kRPCTimeoutErrorStr, timedOutFuture.timeout_.count());
+        fmt::format(kRpcTimeoutErrorStr, timedOutFuture.timeout_.count());
     auto err = makeRPCError(errStr, RPCErrorType::TIMEOUT);
 
     if (!timedOutFuture.future_->hasError()) {
@@ -17,6 +17,9 @@ constexpr float kDefaultRpcTimeoutSeconds = 60;
 // timeout for RPCs.
 constexpr float kUnsetRpcTimeout = -1;
 constexpr auto kDefaultInitMethod = "env://";
+constexpr float kSecToMsConversion = 1000;
+constexpr auto kRpcTimeoutErrorStr =
+    "RPC ran for more than set timeout ({} ms) and will now be marked with an error";
 
 using steady_clock_time_point =
     std::chrono::time_point<std::chrono::steady_clock>;
@@ -22,16 +22,12 @@ namespace {
 const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
 const std::string kDefaultUvAddress = "127.0.0.1";
 
-constexpr long kToMilliseconds = 1000;
-
 const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
 const std::string kThreadPoolSize = "agent.thread_pool_size";
 const std::string kNumIdleThreads = "agent.num_idle_threads";
 const std::string kClientActiveCalls = "agent.client_active_calls";
 const std::string kServerActiveCalls = "agent.server_active_calls";
 const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
-const std::string kRpcTimeoutErrorStr =
-    "RPC ran for more than set timeout ({} ms) and will now be marked with an error";
 
 inline void checkCPUTensor(const torch::Tensor& tensor) {
   TORCH_CHECK(
@@ -273,7 +269,7 @@ TensorPipeAgent::TensorPipeAgent(
           WorkerInfo(std::move(selfName), selfId),
           std::move(cb),
           std::chrono::milliseconds(
-              (long)(opts.rpcTimeoutSeconds * kToMilliseconds))),
+              (long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
       opts_(std::move(opts)),
       threadPool_(opts_.numWorkerThreads),
       context_(std::make_shared<tensorpipe::Context>(
@@ -685,7 +681,7 @@ std::shared_ptr<FutureMessage> TensorPipeAgent::send(
   auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
       ? getRpcTimeout()
       : std::chrono::milliseconds(
-            static_cast<int>(rpcTimeoutSeconds * kToMilliseconds));
+            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
 
   // We only add to the timeoutMap_ if the timeout is not 0. Per our
   // documentation, a user-provided timeout of 0 indicates the RPC should never
@@ -6,10 +6,6 @@ namespace torch {
 namespace distributed {
 namespace rpc {
 
-namespace {
-constexpr auto kSecToMsConversion = 1000;
-}
-
 std::string fromVec(const std::vector<char>& vec) {
   return std::string(vec.begin(), vec.end());
 }
@@ -12,6 +12,7 @@ from . import (
     PyRRef,
     RemoteProfilerManager,
     WorkerInfo,
+    get_rpc_timeout,
     _cleanup_python_rpc_handler,
     _delete_all_user_and_unforked_owner_rrefs,
     _destroy_rref_context,
@@ -34,7 +35,7 @@ from .internal import (
     _build_rpc_profiling_key,
 )
 
-from .constants import UNSET_RPC_TIMEOUT
+from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
 
 
 logger = logging.getLogger(__name__)
@@ -142,7 +143,7 @@ def _broadcast_to_followers(sequence_id, objects_map):
 
 
 @_require_initialized
-def _all_gather(obj):
+def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
     r"""
     This is similar to torch.distributed.all_gather(), but is using RPC. It
     picks the worker with the smallest name (alphabetic order) as the leader.
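The docstring above explains that the worker with the smallest name (alphabetic order) acts as the leader; a standalone sketch of that selection rule, with made-up worker names rather than the actual implementation:

    # Every worker sees the same name list, so min() deterministically yields
    # the same leader on all of them.
    worker_names = ["trainer2", "trainer0", "trainer1"]   # hypothetical names
    self_name = "trainer1"                                # this worker's own name
    leader_name = min(worker_names)                       # "trainer0"
    is_leader = self_name == leader_name                  # False on this worker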
@@ -163,8 +164,8 @@ def _all_gather(obj):
         _all_gather_sequence_id += 1
 
     is_leader = leader_name == self_name
-    # Set a long enough timeout for all shutdown messages to be processed.
-    timeout = 5  # second
+    if timeout == UNSET_RPC_TIMEOUT:
+        timeout = get_rpc_timeout()
 
     # Phase 1: Followers send it's object to the leader
     if is_leader:
@@ -178,9 +179,7 @@ def _all_gather(obj):
             )
 
     with _all_gather_dict_lock:
-        states = _all_gather_sequence_id_to_states[
-            sequence_id
-        ]
+        states = _all_gather_sequence_id_to_states[sequence_id]
     states.proceed_signal.wait()
 
     # Phase 2: Leader broadcast gathered results to all followers
@@ -207,7 +206,7 @@ def _all_gather(obj):
     if errors:
         raise RuntimeError(
             f"Followers {[e[0] for e in errors]} timed out in _all_gather "
-            f"after {timeout} seconds. The first exception is {errors[0][1]}"
+            f"after {timeout:.2f} seconds. The first exception is {errors[0][1]}"
         )
 
     return states.gathered_objects
@@ -223,7 +222,7 @@ def _wait_all_workers():
     framework will work after this method returns.
     """
     try:
-        _all_gather(None)
+        _all_gather(None, timeout=DEFAULT_SHUTDOWN_TIMEOUT)
     except RuntimeError as ex:
         logger.error(
             f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
@@ -12,6 +12,7 @@ from . import (
 # For any RpcAgent.
 DEFAULT_RPC_TIMEOUT_SEC = _DEFAULT_RPC_TIMEOUT_SEC
 DEFAULT_INIT_METHOD = _DEFAULT_INIT_METHOD
+DEFAULT_SHUTDOWN_TIMEOUT = 5.0
 
 # For ProcessGroupAgent.
 DEFAULT_NUM_SEND_RECV_THREADS = _DEFAULT_NUM_SEND_RECV_THREADS
@@ -131,6 +131,19 @@ class MyPickleClass:
         self.t = val
 
 
+class SlowPickleClass:
+    def __init__(self, t):
+        self.t = t
+
+    def __getstate__(self):
+        time.sleep(self.t)
+        return (self.t, )
+
+    def __setstate__(self, obj):
+        self.t = obj[0]
+        time.sleep(self.t)
+
+
 class MyClass:
     def __init__(self, a):
         self.a = a
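The SlowPickleClass added above stalls inside __getstate__ and __setstate__, so merely serializing the RPC argument outlasts a short timeout. A rough standalone illustration of that mechanism, using plain pickle with no RPC involved:

    import pickle
    import time

    class Slow:
        def __getstate__(self):
            time.sleep(0.5)  # simulate expensive serialization
            return {}

    start = time.time()
    pickle.dumps(Slow())  # takes ~0.5 s, enough to exceed a 0.1 s RPC timeout
    print(f"pickling took {time.time() - start:.2f} s")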
@@ -931,6 +944,20 @@ class RpcTest(RpcAgentTestFixture):
 
         self.assertEqual(expected, results)
 
+    @dist_init
+    def test_all_gather_timeout(self):
+        rpc._set_rpc_timeout(0.1)
+
+        if self.rank == 0:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "timed out in _all_gather after 0\\.10 seconds"
+            ):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+        else:
+            with self.assertRaisesRegex(RuntimeError, "timeout.*100 ms"):
+                rpc.api._all_gather(SlowPickleClass(0.5))
+
     @dist_init
     def test_graceful_shutdown_with_uneven_workload(self):
         """Test graceful termination."""