Add local shutdown to process group agent (#30330)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30330

Local shutdown of the process group agent is now possible due to previous changes made in `gloo` and `ProcessGroupGloo`: we `abort` the listener thread that is blocked waiting for a message, and join all other threads. The API is changed so that the previous `wait_all_workers` no longer destroys the agent; that is now done by a new `shutdown` method. All call sites are updated accordingly.
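
For illustration only (this snippet is not part of the diff), the user-facing change amounts to the following hedged sketch; the worker names are placeholders and MASTER_ADDR/MASTER_PORT are assumed to already be set in the environment:

    # Hypothetical usage sketch of the API change in this PR.
    # Before: callers ended with rpc.wait_all_workers(), which also destroyed
    # the agent. After: callers end with rpc.shutdown().
    import torch
    import torch.distributed.rpc as rpc

    rpc.init_rpc(name="worker0", rank=0, world_size=2)
    rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2, 2), 1))

    # Graceful shutdown: block until every worker reaches this point and all
    # outstanding work has completed, then destroy the local agent.
    rpc.shutdown()

    # The new local shutdown instead tears down only this process's agent,
    # without waiting for peers:
    # rpc.shutdown(graceful=False)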

ghstack-source-id: 94673884

Test Plan: Unit tests pass.

Reviewed By: mrshenli

Differential Revision: D18661775

fbshipit-source-id: 5aaa7c14603e18253394224994f6cd43234301c2
Rohan Varma
2019-11-27 22:31:38 -08:00
committed by Facebook Github Bot
parent 7ac8efa689
commit 1350b99de4
10 changed files with 230 additions and 94 deletions

View File

@@ -314,13 +314,14 @@ The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows:
 Simple end to end example
 ^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Putting it all together, a very simple end to end example using distributed
-autograd and distributed optimizer is as follows:
+Putting it all together, the following is a simple end to end example using
+distributed autograd and the distributed optimizer. If the code is placed into a
+file called "dist_autograd_simple.py", it can be run with the command
+:code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`:
 
 .. code::
 
     import multiprocessing as mp
-    from tempfile import NamedTemporaryFile
    import torch
     import torch.distributed.autograd as dist_autograd
     from torch.distributed import rpc
@@ -330,50 +331,52 @@ autograd and distributed optimizer is as follows:
     def random_tensor():
         return torch.rand((3, 3), requires_grad=True)
 
-    def _run_process(self_rank, dst_rank):
-        self_name = "worker{}".format(self_rank)
+    def _run_process(rank, dst_rank, world_size):
+        name = "worker{}".format(rank)
         dst_name = "worker{}".format(dst_rank)
 
         # Initialize RPC.
         rpc.init_rpc(
-            self_name=self_name,
-            self_rank=self_rank,
-            worker_name_to_id={"worker0": 0, "worker1": 1},
+            name=name,
+            rank=rank,
+            world_size=world_size
         )
 
         # Use a distributed autograd context.
         with dist_autograd.context() as context_id:
             # Forward pass (create references on remote nodes).
             rref1 = rpc.remote(dst_name, random_tensor)
             rref2 = rpc.remote(dst_name, random_tensor)
             loss = rref1.to_here() + rref2.to_here()
 
             # Backward pass (run distributed autograd).
             dist_autograd.backward([loss.sum()])
 
             # Build DistributedOptimizer.
             dist_optim = DistributedOptimizer(
                 optim.SGD,
                 [rref1, rref2],
                 lr=0.05,
             )
 
             # Run the distributed optimizer step.
             dist_optim.step()
 
-    def run_process(self_rank, dst_rank):
-        _run_process(self_rank, dst_rank)
-        rpc.wait_all_workers()
+    def run_process(rank, dst_rank, world_size):
+        _run_process(rank, dst_rank, world_size)
+        rpc.shutdown()
 
     processes = []
 
-    # Run two workers.
-    for i in range(2):
-        p = mp.Process(target=run_process, args=(i, (i + 1) % 2))
+    # Run world_size workers.
+    world_size = 2
+    for i in range(world_size):
+        p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
         p.start()
         processes.append(p)
 
     for p in processes:
         p.join()
 
 .. _RFC: https://github.com/pytorch/pytorch/issues/23110

View File

@@ -55,7 +55,7 @@ This library provides primitives allowing users to create and modify references
 .. autofunction:: rpc_async
 .. autofunction:: remote
 .. autofunction:: get_worker_info
-.. autofunction:: wait_all_workers
+.. autofunction:: shutdown
 
 Distributed Autograd Framework
 ------------------------------

View File

@@ -1360,7 +1360,7 @@ class DistAutogradTest(RpcAgentTestFixture):
             # receive gradients from the node that received an error (and as a
             # result it didn't execute the rest of the graph).
             dist.barrier()
-            rpc.wait_all_workers()
+            rpc.shutdown()
             sys.exit(0)

View File

@@ -127,7 +127,7 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
             # since we need to shutdown the RPC agent. If we don't shutdown the
             # RPC agent, tests would fail since RPC agent threads, locks and
             # condition variables are not properly terminated.
-            rpc.wait_all_workers()
+            rpc.shutdown()
 
         return return_value

View File

@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import concurrent.futures
 from datetime import timedelta
 import sys
+import time
 import unittest
 from collections import namedtuple
 from unittest import mock
@@ -18,6 +19,21 @@ from torch.distributed.rpc.api import _use_rpc_pickler
 from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
 from rpc_agent_test_fixture import RpcAgentTestFixture
 
+rpc_done = [False, False, False, False]
+
+# TODO: dedupe this with the code in dist_autograd_test.py.
+# Send rpc done info and context_id to
+# dst_rank = (self.rank + rank_distance) % self.world_size
+# we don't need a lock here since the GIL is held while executing remote
+# python UDFs, so access is serialized across several workers.
+def _set_rpc_done(rank_distance):
+    global rpc_done
+    rpc_done[rank_distance] = True
+
+def _check_rpc_done(rank_distance):
+    while not rpc_done[rank_distance]:
+        # yield control to other threads
+        time.sleep(0)
 
 def requires_process_group_agent(message=""):
     def decorator(old_func):
@@ -127,7 +143,6 @@ def my_tensor_function(a, b):
     return a + b
 
 
 def my_sleep_func(seconds=1):
-    import time
     time.sleep(seconds)
 
@@ -346,7 +361,7 @@ class RpcTest(RpcAgentTestFixture):
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options,
         )
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init(setup_rpc=False)
     def test_reinit(self):
@@ -379,8 +394,9 @@ class RpcTest(RpcAgentTestFixture):
             world_size=self.world_size,
             rpc_backend_options=self.rpc_backend_options,
        )
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
+    @unittest.skip("test_invalid_names is flaky, see https://github.com/pytorch/pytorch/issues/25912")
     @dist_init(setup_rpc=False)
     def test_invalid_names(self):
         with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
@@ -447,8 +463,8 @@ class RpcTest(RpcAgentTestFixture):
         from torch.distributed.rpc.api import _agent
         self.assertEqual(_agent, None)
-        # wait_all_workers() should not do anything as _agent is None
-        rpc.wait_all_workers()
+        # shutdown() should not do anything as _agent is None
+        rpc.shutdown()
 
         # We need this barrier here because although init_process_group is
         # blocking, it does not guarantee that all ranks are done with
         # initialization after the call. We did run into issues with it where
@@ -525,7 +541,7 @@ class RpcTest(RpcAgentTestFixture):
         self.assertEqual(ret, torch.ones(n, n) * 2)
 
     @dist_init(setup_rpc=False)
-    def test_wait_all_workers(self):
+    def test_shutdown(self):
         # Initialize RPC.
         rpc.init_rpc(
             name="worker%d" % self.rank,
@@ -543,7 +559,7 @@ class RpcTest(RpcAgentTestFixture):
             args=(torch.ones(n, n), torch.ones(n, n)),
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
         with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
             rpc.rpc_sync(
@@ -552,8 +568,8 @@ class RpcTest(RpcAgentTestFixture):
                 args=(torch.ones(n, n), torch.ones(n, n)),
             )
 
-        # it's safe to call wait_all_workers() multiple times
-        rpc.wait_all_workers()
+        # it's safe to call shutdown() multiple times
+        rpc.shutdown()
 
     @dist_init
     def test_expected_src(self):
@@ -717,8 +733,6 @@ class RpcTest(RpcAgentTestFixture):
         self.assertEqual(ret, torch.ones(2, 2) + 1)
 
     def _stress_test_rpc(self, f, repeat=1000, args=()):
-        import time
-
         n = self.rank + 1
         dst_rank = n % self.world_size
         futs = []
@@ -1138,7 +1152,7 @@ class RpcTest(RpcAgentTestFixture):
             import torch.distributed.rpc.api as api
             api._ignore_rref_leak = True
 
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init(setup_rpc=False)
     def test_rref_leak(self):
@@ -1149,6 +1163,68 @@ class RpcTest(RpcAgentTestFixture):
     def test_ignore_rref_leak(self):
         self._test_rref_leak(ignore_leak=True)
 
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_local_shutdown(self):
+        # test that we can start RPC and then immediately locally shutdown
+        # without sending any messages.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[
+                dist_utils.TEST_CONFIG.rpc_backend_name
+            ],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_local_shutdown_with_rpc(self):
+        # test that we can start RPC, send RPCs, and then run local shutdown.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[
+                dist_utils.TEST_CONFIG.rpc_backend_name
+            ],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            "worker{}".format(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # wait for RPCs to be done, so that some workers don't try to shut down
+        # too early.
+        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(1,))
+        _check_rpc_done(1)
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_wait_all_workers_and_shutdown(self):
+        # This tests ensures that both rpc._wait_all_workers() and rpc.shutdown() can be
+        # called without errors being raised due to attempting to shut down
+        # multiple times.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options
+        )
+        from torch.distributed.rpc.api import _wait_all_workers
+        # intentional call to internal _wait_all_workers.
+        _wait_all_workers()
+        rpc.shutdown()
+
     @dist_init(setup_rpc=False)
     def test_get_rpc_timeout(self):
         timeout = timedelta(seconds=1)
@@ -1167,7 +1243,7 @@ class RpcTest(RpcAgentTestFixture):
         )
         set_timeout = rpc.get_rpc_timeout()
         self.assertEqual(timeout, set_timeout)
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init
     @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
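
A note on the new local-shutdown tests above: with graceful=False nothing synchronizes the workers, so test_local_shutdown_with_rpc uses the _set_rpc_done/_check_rpc_done flags to confirm that the peer has finished sending RPCs before this process tears down its agent. Below is a minimal standalone sketch of that pattern, with placeholder worker names, address, and port (not taken from the diff):

    # Hedged sketch of the "confirm the peer is done, then shut down locally"
    # pattern used by test_local_shutdown_with_rpc. Run one process per rank,
    # e.g. torch.multiprocessing.spawn(run, args=(2,), nprocs=2).
    import os
    import time

    import torch
    import torch.distributed.rpc as rpc

    peer_done = False  # set remotely by the peer once it has finished its RPCs

    def _set_peer_done():
        global peer_done
        peer_done = True

    def run(rank, world_size=2):
        os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder
        os.environ.setdefault("MASTER_PORT", "29500")      # placeholder
        rpc.init_rpc(name="worker{}".format(rank), rank=rank, world_size=world_size)

        dst = "worker{}".format((rank + 1) % world_size)
        # Do some work, then tell the peer that we are done sending to it.
        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
        rpc.rpc_sync(dst, _set_peer_done)

        # Wait until the peer has also finished sending to us, so the local
        # (non-graceful) shutdown does not race with in-flight requests.
        while not peer_done:
            time.sleep(0.01)

        rpc.shutdown(graceful=False)  # local shutdown: do not wait for other workers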

View File

@@ -114,9 +114,9 @@ Otherwise, throws an exception.
             return PyRRef::unpickle(t);
           }));
 
-  // future.wait() should not be called after wait_all_workers(), e.g.,
-  // pythonRpcHandler is cleaned up in wait_all_workers(), after
-  // wait_all_workers(), python objects returned from rpc python call can not be
+  // future.wait() should not be called after shutdown(), e.g.,
+  // pythonRpcHandler is cleaned up in shutdown(), after
+  // shutdown(), python objects returned from rpc python call can not be
   // resolved.
   auto futureMessage =
       shared_ptr_class_<FutureMessage>(module, "FutureMessage")
@@ -162,6 +162,10 @@ Otherwise, throws an exception.
           "join",
           &ProcessGroupAgent::join,
           py::call_guard<py::gil_scoped_release>())
+      .def(
+          "shutdown",
+          &ProcessGroupAgent::shutdown,
+          py::call_guard<py::gil_scoped_release>())
       .def(
           "sync",
          &ProcessGroupAgent::sync,

View File

@@ -127,7 +127,6 @@ ProcessGroupAgent::ProcessGroupAgent(
           WorkerInfo(std::move(workerName), pg->getRank()),
           c10::guts::make_unique<RequestCallbackImpl>(),
           rpcTimeout),
-      shutdown_{false},
       pg_(std::move(pg)),
       sendCounts_(pg_->getSize()),
       recvCounts_(pg_->getSize()),
@@ -184,30 +183,12 @@ std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
 }
 
 void ProcessGroupAgent::join() {
-  // Every process i sends a SHUTDOWN message to process i + 1. This is
-  // necessary for now because:
-  // 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
-  //    feed it a message or kill the thread.
-  // 2. A GLOO process cannot send message to itself. (there is an ongoing
-  //    effort to fix this problem).
-  shutdown_.store(true);
   sync();
-  // This is needed in case no futures were created, otherwise the future
-  // timeout watchdog would sleep forever.
-  futureTimeoutCV_.notify_one();
   std::unique_lock<std::mutex> lock(futureMutex_);
   futureCV_.wait(
       lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
   lock.unlock();
   pg_->barrier()->wait();
-  int dst = (pg_->getRank() + 1) % pg_->getSize();
-  enqueueSend(
-      SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
-  threadPool_.waitWorkComplete();
-  listenerThread_.join();
-  futureTimeoutThread_.join();
-  PythonRpcHandler::getInstance().cleanup();
 }
 
 bool ProcessGroupAgent::hasPendingMessage() {
@@ -273,14 +254,38 @@ void ProcessGroupAgent::sync() {
 }
 
 void ProcessGroupAgent::start() {
+  {
+    std::lock_guard<std::mutex> futureLock{futureMutex_};
+    rpcRunning_.store(true);
+  }
   listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
   futureTimeoutThread_ =
       std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
 }
 
+void ProcessGroupAgent::shutdown() {
+  LOG(INFO) << "Shutting down ProcessGroupAgent.";
+  std::unique_lock<std::mutex> lock{futureMutex_};
+  if (!rpcRunning_.exchange(false)) {
+    return;
+  }
+  lock.unlock();
+  futureTimeoutCV_.notify_one();
+  futureTimeoutThread_.join();
+  {
+    std::unique_lock<std::mutex> lock(recvWorkMutex_);
+    if (recvWork_) {
+      recvWork_->abort();
+    }
+  }
+  threadPool_.waitWorkComplete();
+  listenerThread_.join();
+}
+
 std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
     const WorkerInfo& to,
     Message&& message) {
+  TORCH_CHECK(rpcRunning_.load(), "ProcessGroupAgent hasn't started.")
   TORCH_CHECK(
       to.id_ < (worker_id_t)pg_->getSize(),
       "Destination rank is out of bound, got ",
@@ -460,10 +465,19 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
 }
 
 void ProcessGroupAgent::listenLoop() {
-  while (true) {
+  while (rpcRunning_.load()) {
     // rank, tensor size, message type
     std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
-    pg_->recvAnysource(preamble, pg_->getRank())->wait();
+    auto work = pg_->recvAnysource(preamble, pg_->getRank());
+    {
+      std::lock_guard<std::mutex> guard(recvWorkMutex_);
+      recvWork_ = work;
+    }
+
+    if (!rpcRunning_.load() || !work->wait() /* aborted */) {
+      return;
+    }
+
     int64_t* preamble_items = preamble.front().storage().data<int64_t>();
     auto srcRank = preamble_items[0];
@@ -487,9 +501,12 @@ void ProcessGroupAgent::listenLoop() {
 }
 
 void ProcessGroupAgent::pollTimedOutRPCs() {
-  while (!shutdown_.load()) {
-    std::chrono::milliseconds sleepTime;
+  while (true) {
     std::unique_lock<std::mutex> lock{futureMutex_};
+    if (!rpcRunning_.load()) {
+      return;
+    }
+    std::chrono::milliseconds sleepTime;
     // Estimate amount of time the first future will time out in, and sleep
     // for that long.
     // if there are no futures or the first future's RPC timeout is set to 0
@@ -509,7 +526,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
       futureTimeoutCV_.wait_for(lock, sleepTime);
     }
 
-    if (shutdown_.load()) {
+    if (!rpcRunning_.load()) {
       return;
     }

View File

@@ -59,6 +59,8 @@ class ProcessGroupAgent : public RpcAgent {
   void start() override;
 
+  void shutdown() override;
+
  protected:
   // This method wraps the destination information and the message into a
   // SendWork object, and put the SendWork into a queue. Another thread will
@@ -145,10 +147,6 @@ class ProcessGroupAgent : public RpcAgent {
     return ++nextId_;
   }
 
-  // atomic bool indicating if join() has been called and background threads
-  // should shutdown.
-  std::atomic_bool shutdown_;
-
   std::shared_ptr<c10d::ProcessGroup> pg_;
   // worker name -> rank
   std::unordered_map<std::string, int> nameMap_;
@@ -161,12 +159,23 @@ class ProcessGroupAgent : public RpcAgent {
   MessageCounter recvCounts_;
 
   std::atomic<int64_t> nextId_;
+  // atomic bool indicating if this agent is running. It is set in
+  // ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
+  // ProcessGroupAgent::join. It controls whether several background threads
+  // should be running.
+  // We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
+  // conditions when notifying condition variables.
+  std::atomic<bool> rpcRunning_{false};
   // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
   // when using the same tag.
   std::vector<std::mutex> sendMutexes_;
   std::thread listenerThread_;
   // A thread to poll existing futures and check for timed out ones.
   std::thread futureTimeoutThread_;
+  // Lock and shared ptr to currently pending work, set in listenloop() and
+  // interruptible in shutdown().
+  std::mutex recvWorkMutex_;
+  std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
   // A threadPool that processing both SendWork and RecvWork. There are two
   // motivations for adding a ThreadPool:
   // (1) RPC serialization/deserialization and processing can be expensive,

View File

@@ -127,7 +127,11 @@ class TORCH_API RpcAgent {
   virtual void sync() = 0;
 
   // start accepting requests
-  virtual void start() {}
+  virtual void start() = 0;
+
+  // Stop accepting requests and shutdown the RPC framework as soon as possible
+  // by terminating all RPC threads.
+  virtual void shutdown() = 0;
 
   // Set the default rpc agent.
   static void setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent);

View File

@@ -58,12 +58,34 @@ def _require_initialized(func):
     return wrapper
 
 
-def wait_all_workers():
+def _wait_all_workers():
     r"""
-    Block until all local and remote RPC processes reach this method, and then
-    destroy local the RPC agent. Every RPC process must call this method before
-    exit. This should be used to terminate the RPC framework, and there is no
-    guarantee that the RPC framework will work after this method returns.
+    Block until all local and remote RPC processes reach this method and wait
+    for all outstanding work to complete. Every RPC process must call this
+    method before exit to perform a graceful shutdown. This should be used to
+    terminate the RPC framework, and there is no guarantee that the RPC
+    framework will work after this method returns.
+    """
+    global _agent
+    if _agent:
+        _agent.join()
+
+
+def shutdown(graceful=True):
+    r"""
+    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
+    stops the local agent from accepting outstanding requests, and shuts
+    down the RPC framework by terminating all RPC threads. If graceful=True,
+    then this will block until all local and remote RPC processes reach this
+    method and wait for all outstanding work to complete. Otherwise, if
+    graceful=False, then this is a local shutdown, and it does not wait for
+    other RPC processes to reach this method.
+
+    Arguments:
+        graceful (bool): Whether to do a graceful shutdown or not. If True,
+                         this will block until all local and remote RPC
+                         processes have reached this method and wait for all
+                         outstanding work to complete.
 
     Example::
@@ -73,27 +95,28 @@ def wait_all_workers():
         >>> # do some work
         >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
         >>> # ready to shutdown
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
     On worker 1:
 
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
         >>> # wait for worker 0 to finish work, and then shutdown.
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     global _agent
     if _agent:
-        _agent.join()
-        _agent = None
+        if graceful:
+            _wait_all_workers()
         _destroy_rref_context(_ignore_rref_leak)
-        # clean up python rpc handler in wait_all_workers(), see comments in
+        _agent.shutdown()
+        # clean up python rpc handler in shutdown(), see comments in
         # PythonRpcHandler::cleanup(), call it in python API because the
         # cleanup() function has python dependency, it assumes python
         # interpreter exists
         _cleanup_python_rpc_handler()
+        _agent = None
 
 
-# TODO: add a context manager to wrap _init_rpc_backend and wait_all_workers
+# TODO: add a context manager to wrap _init_rpc_backend and shutdown
 def _init_rpc_backend(
     backend=backend_registry.BackendType.PROCESS_GROUP,
     store=None,
@@ -205,12 +228,12 @@ def remote(to, func, args=None, kwargs=None):
         >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
         >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
         >>> x = rref1.to_here() + rref2.to_here()
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
     On worker 1:
 
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     qualified_name = torch.jit._find_builtin(func)
@@ -274,12 +297,12 @@ def rpc_sync(to, func, args=None, kwargs=None):
         >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
         >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
     On worker 1:
 
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut.wait()
@@ -314,12 +337,12 @@ def rpc_async(to, func, args=None, kwargs=None):
         >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
         >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
         >>> result = fut1.wait() + fut2.wait()
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
     On worker 1:
 
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut
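
Taken together with the tests above, the behavior documented for the new shutdown() can be summarized in the following hedged sketch (a hypothetical script, not part of the diff; worker names, address, and port are placeholders): after a graceful shutdown(), further RPCs raise "RPC has not been initialized", and calling shutdown() again is safe.

    import os

    import torch
    import torch.distributed.rpc as rpc

    def worker(rank, world_size=2):
        os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder
        os.environ.setdefault("MASTER_PORT", "29500")      # placeholder
        rpc.init_rpc(name="worker{}".format(rank), rank=rank, world_size=world_size)

        dst = "worker{}".format((rank + 1) % world_size)
        rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))

        rpc.shutdown()  # graceful: blocks until all workers reach this point

        try:
            rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
        except RuntimeError as err:
            print("expected:", err)  # "RPC has not been initialized"

        rpc.shutdown()  # safe to call more than once; no-op once the agent is gone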