Add local shutdown to process group agent (#30330)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30330

This is now possible due to previous changes made in `gloo` and `ProcessGroupGloo`. We `abort` the listener thread that is waiting for a message and join all other threads. The API is changed so that the previous `wait_all_workers` no longer destroys the agent; that is now done by a new `shutdown` method. All callsites are updated accordingly.
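
A minimal sketch of the updated callsite pattern, shown from worker 0's side (the worker names, ranks, world size, and tensor arguments are illustrative, not taken from this diff):

    import torch
    import torch.distributed.rpc as rpc

    rpc.init_rpc(name="worker0", rank=0, world_size=2)
    ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))

    # Graceful shutdown: block until all workers reach this point and all
    # outstanding work is done, then destroy the local agent. Callsites that
    # previously ended with rpc.wait_all_workers() now end with this call.
    rpc.shutdown()

    # Alternatively, on the PROCESS_GROUP backend, shut down locally without
    # waiting for other workers:
    #     rpc.shutdown(graceful=False)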

ghstack-source-id: 94673884

Test Plan: Unit tests pass.

Reviewed By: mrshenli

Differential Revision: D18661775

fbshipit-source-id: 5aaa7c14603e18253394224994f6cd43234301c2
Rohan Varma
2019-11-27 22:31:38 -08:00
committed by Facebook Github Bot
parent 7ac8efa689
commit 1350b99de4
10 changed files with 230 additions and 94 deletions


@ -314,13 +314,14 @@ The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows:
Simple end to end example
^^^^^^^^^^^^^^^^^^^^^^^^^
Putting it all together, a very simple end to end example using distributed
autograd and distributed optimizer is as follows:
Putting it all together, the following is a simple end to end example using
distributed autograd and the distributed optimizer. If the code is placed into a
file called "dist_autograd_simple.py", it can be run with the command
:code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`:
.. code::
import multiprocessing as mp
from tempfile import NamedTemporaryFile
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
@ -330,50 +331,52 @@ autograd and distributed optimizer is as follows:
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(self_rank, dst_rank):
self_name = "worker{}".format(self_rank)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
self_name=self_name,
self_rank=self_rank,
worker_name_to_id={"worker0": 0, "worker1": 1},
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward([loss.sum()])
# Backward pass (run distributed autograd).
dist_autograd.backward([loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step()
# Run the distributed optimizer step.
dist_optim.step()
def run_process(self_rank, dst_rank):
_run_process(self_rank, dst_rank)
rpc.wait_all_workers()
def run_process(rank, dst_rank, world_size):
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
processes = []
# Run two workers.
for i in range(2):
p = mp.Process(target=run_process, args=(i, (i + 1) % 2))
# Run world_size workers.
world_size = 2
for i in range(world_size):
p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
p.start()
processes.append(p)
for p in processes:
p.join()
.. _RFC: https://github.com/pytorch/pytorch/issues/23110


@ -55,7 +55,7 @@ This library provides primitives allowing users to create and modify references
.. autofunction:: rpc_async
.. autofunction:: remote
.. autofunction:: get_worker_info
.. autofunction:: wait_all_workers
.. autofunction:: shutdown
Distributed Autograd Framework
------------------------------


@ -1360,7 +1360,7 @@ class DistAutogradTest(RpcAgentTestFixture):
# receive gradients from the node that received an error (and as a
# result it didn't execute the rest of the graph).
dist.barrier()
rpc.wait_all_workers()
rpc.shutdown()
sys.exit(0)


@ -127,7 +127,7 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
# since we need to shutdown the RPC agent. If we don't shutdown the
# RPC agent, tests would fail since RPC agent threads, locks and
# condition variables are not properly terminated.
rpc.wait_all_workers()
rpc.shutdown()
return return_value


@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import concurrent.futures
from datetime import timedelta
import sys
import time
import unittest
from collections import namedtuple
from unittest import mock
@ -18,6 +19,21 @@ from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
from rpc_agent_test_fixture import RpcAgentTestFixture
rpc_done = [False, False, False, False]
# TODO: dedupe this with the code in dist_autograd_test.py.
# Send "RPC done" info to
# dst_rank = (self.rank + rank_distance) % self.world_size.
# We don't need a lock here since the GIL is held while executing remote
# Python UDFs, so access is serialized across workers.
def _set_rpc_done(rank_distance):
global rpc_done
rpc_done[rank_distance] = True
def _check_rpc_done(rank_distance):
while not rpc_done[rank_distance]:
# yield control to other threads
time.sleep(0)
def requires_process_group_agent(message=""):
def decorator(old_func):
@ -127,7 +143,6 @@ def my_tensor_function(a, b):
return a + b
def my_sleep_func(seconds=1):
import time
time.sleep(seconds)
@ -346,7 +361,7 @@ class RpcTest(RpcAgentTestFixture):
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)
rpc.wait_all_workers()
rpc.shutdown()
@dist_init(setup_rpc=False)
def test_reinit(self):
@ -379,8 +394,9 @@ class RpcTest(RpcAgentTestFixture):
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)
rpc.wait_all_workers()
rpc.shutdown()
@unittest.skip("test_invalid_names is flaky, see https://github.com/pytorch/pytorch/issues/25912")
@dist_init(setup_rpc=False)
def test_invalid_names(self):
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
@ -447,8 +463,8 @@ class RpcTest(RpcAgentTestFixture):
from torch.distributed.rpc.api import _agent
self.assertEqual(_agent, None)
# wait_all_workers() should not do anything as _agent is None
rpc.wait_all_workers()
# shutdown() should not do anything as _agent is None
rpc.shutdown()
# We need this barrier here because although init_process_group is
# blocking, it does not guarantee that all ranks are done with
# initialization after the call. We did run into issues with it where
@ -525,7 +541,7 @@ class RpcTest(RpcAgentTestFixture):
self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init(setup_rpc=False)
def test_wait_all_workers(self):
def test_shutdown(self):
# Initialize RPC.
rpc.init_rpc(
name="worker%d" % self.rank,
@ -543,7 +559,7 @@ class RpcTest(RpcAgentTestFixture):
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
rpc.wait_all_workers()
rpc.shutdown()
with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
rpc.rpc_sync(
@ -552,8 +568,8 @@ class RpcTest(RpcAgentTestFixture):
args=(torch.ones(n, n), torch.ones(n, n)),
)
# it's safe to call wait_all_workers() multiple times
rpc.wait_all_workers()
# it's safe to call shutdown() multiple times
rpc.shutdown()
@dist_init
def test_expected_src(self):
@ -717,8 +733,6 @@ class RpcTest(RpcAgentTestFixture):
self.assertEqual(ret, torch.ones(2, 2) + 1)
def _stress_test_rpc(self, f, repeat=1000, args=()):
import time
n = self.rank + 1
dst_rank = n % self.world_size
futs = []
@ -1138,7 +1152,7 @@ class RpcTest(RpcAgentTestFixture):
import torch.distributed.rpc.api as api
api._ignore_rref_leak = True
rpc.wait_all_workers()
rpc.shutdown()
@dist_init(setup_rpc=False)
def test_rref_leak(self):
@ -1149,6 +1163,68 @@ class RpcTest(RpcAgentTestFixture):
def test_ignore_rref_leak(self):
self._test_rref_leak(ignore_leak=True)
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown(self):
# test that we can start RPC and then immediately shut down locally
# without sending any messages.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[
dist_utils.TEST_CONFIG.rpc_backend_name
],
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)
# pass in graceful=False to ensure that we don't wait for other workers.
rpc.shutdown(graceful=False)
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_local_shutdown_with_rpc(self):
# test that we can start RPC, send RPCs, and then run local shutdown.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[
dist_utils.TEST_CONFIG.rpc_backend_name
],
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options,
)
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
# wait for RPCs to be done, so that some workers don't try to shut down
# too early.
rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(1,))
_check_rpc_done(1)
# pass in graceful=False to ensure that we don't wait for other workers.
rpc.shutdown(graceful=False)
@dist_init(setup_rpc=False)
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
def test_wait_all_workers_and_shutdown(self):
# This test ensures that both rpc._wait_all_workers() and rpc.shutdown() can be
# called without errors being raised due to attempting to shut down
# multiple times.
rpc.init_rpc(
name="worker%d" % self.rank,
backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
rank=self.rank,
world_size=self.world_size,
rpc_backend_options=self.rpc_backend_options
)
from torch.distributed.rpc.api import _wait_all_workers
# intentional call to internal _wait_all_workers.
_wait_all_workers()
rpc.shutdown()
@dist_init(setup_rpc=False)
def test_get_rpc_timeout(self):
timeout = timedelta(seconds=1)
@ -1167,7 +1243,7 @@ class RpcTest(RpcAgentTestFixture):
)
set_timeout = rpc.get_rpc_timeout()
self.assertEqual(timeout, set_timeout)
rpc.wait_all_workers()
rpc.shutdown()
@dist_init
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")


@ -114,9 +114,9 @@ Otherwise, throws an exception.
return PyRRef::unpickle(t);
}));
// future.wait() should not be called after wait_all_workers(), e.g.,
// pythonRpcHandler is cleaned up in wait_all_workers(), after
// wait_all_workers(), python objects returned from rpc python call can not be
// future.wait() should not be called after shutdown(), e.g.,
// pythonRpcHandler is cleaned up in shutdown(), after
// shutdown(), python objects returned from rpc python call can not be
// resolved.
auto futureMessage =
shared_ptr_class_<FutureMessage>(module, "FutureMessage")
@ -162,6 +162,10 @@ Otherwise, throws an exception.
"join",
&ProcessGroupAgent::join,
py::call_guard<py::gil_scoped_release>())
.def(
"shutdown",
&ProcessGroupAgent::shutdown,
py::call_guard<py::gil_scoped_release>())
.def(
"sync",
&ProcessGroupAgent::sync,


@ -127,7 +127,6 @@ ProcessGroupAgent::ProcessGroupAgent(
WorkerInfo(std::move(workerName), pg->getRank()),
c10::guts::make_unique<RequestCallbackImpl>(),
rpcTimeout),
shutdown_{false},
pg_(std::move(pg)),
sendCounts_(pg_->getSize()),
recvCounts_(pg_->getSize()),
@ -184,30 +183,12 @@ std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
}
void ProcessGroupAgent::join() {
// Every process i sends a SHUTDOWN message to process i + 1. This is
// necessary for now because:
// 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
// feed it a message or kill the thread.
// 2. A GLOO process cannot send message to itself. (there is an ongoing
// effort to fix this problem).
shutdown_.store(true);
sync();
// This is needed in case no futures were created, otherwise the future
// timeout watchdog would sleep forever.
futureTimeoutCV_.notify_one();
std::unique_lock<std::mutex> lock(futureMutex_);
futureCV_.wait(
lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
lock.unlock();
pg_->barrier()->wait();
int dst = (pg_->getRank() + 1) % pg_->getSize();
enqueueSend(
SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
threadPool_.waitWorkComplete();
listenerThread_.join();
futureTimeoutThread_.join();
PythonRpcHandler::getInstance().cleanup();
}
bool ProcessGroupAgent::hasPendingMessage() {
@ -273,14 +254,38 @@ void ProcessGroupAgent::sync() {
}
void ProcessGroupAgent::start() {
{
std::lock_guard<std::mutex> futureLock{futureMutex_};
rpcRunning_.store(true);
}
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
futureTimeoutThread_ =
std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
}
void ProcessGroupAgent::shutdown() {
LOG(INFO) << "Shutting down ProcessGroupAgent.";
std::unique_lock<std::mutex> lock{futureMutex_};
if (!rpcRunning_.exchange(false)) {
return;
}
lock.unlock();
futureTimeoutCV_.notify_one();
futureTimeoutThread_.join();
{
std::unique_lock<std::mutex> lock(recvWorkMutex_);
if (recvWork_) {
recvWork_->abort();
}
}
threadPool_.waitWorkComplete();
listenerThread_.join();
}
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const WorkerInfo& to,
Message&& message) {
TORCH_CHECK(rpcRunning_.load(), "ProcessGroupAgent hasn't started.")
TORCH_CHECK(
to.id_ < (worker_id_t)pg_->getSize(),
"Destination rank is out of bound, got ",
@ -460,10 +465,19 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
}
void ProcessGroupAgent::listenLoop() {
while (true) {
while (rpcRunning_.load()) {
// rank, tensor size, message type
std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
pg_->recvAnysource(preamble, pg_->getRank())->wait();
auto work = pg_->recvAnysource(preamble, pg_->getRank());
{
std::lock_guard<std::mutex> guard(recvWorkMutex_);
recvWork_ = work;
}
if (!rpcRunning_.load() || !work->wait() /* aborted */) {
return;
}
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
auto srcRank = preamble_items[0];
@ -487,9 +501,12 @@ void ProcessGroupAgent::listenLoop() {
}
void ProcessGroupAgent::pollTimedOutRPCs() {
while (!shutdown_.load()) {
std::chrono::milliseconds sleepTime;
while (true) {
std::unique_lock<std::mutex> lock{futureMutex_};
if (!rpcRunning_.load()) {
return;
}
std::chrono::milliseconds sleepTime;
// Estimate amount of time the first future will time out in, and sleep
// for that long.
// if there are no futures or the first future's RPC timeout is set to 0
@ -509,7 +526,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
futureTimeoutCV_.wait_for(lock, sleepTime);
}
if (shutdown_.load()) {
if (!rpcRunning_.load()) {
return;
}


@ -59,6 +59,8 @@ class ProcessGroupAgent : public RpcAgent {
void start() override;
void shutdown() override;
protected:
// This method wraps the destination information and the message into a
// SendWork object, and put the SendWork into a queue. Another thread will
@ -145,10 +147,6 @@ class ProcessGroupAgent : public RpcAgent {
return ++nextId_;
}
// atomic bool indicating if join() has been called and background threads
// should shutdown.
std::atomic_bool shutdown_;
std::shared_ptr<c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, int> nameMap_;
@ -161,12 +159,23 @@ class ProcessGroupAgent : public RpcAgent {
MessageCounter recvCounts_;
std::atomic<int64_t> nextId_;
// atomic bool indicating if this agent is running. It is set in
// ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
// ProcessGroupAgent::join. It controls whether several background threads
// should be running.
// We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
// conditions when notifying condition variables.
std::atomic<bool> rpcRunning_{false};
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
// when using the same tag.
std::vector<std::mutex> sendMutexes_;
std::thread listenerThread_;
// A thread to poll existing futures and check for timed out ones.
std::thread futureTimeoutThread_;
// Lock and shared ptr to currently pending work, set in listenLoop() and
// interruptible in shutdown().
std::mutex recvWorkMutex_;
std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
// A threadPool that processing both SendWork and RecvWork. There are two
// motivations for adding a ThreadPool:
// (1) RPC serialization/deserialization and processing can be expensive,


@ -127,7 +127,11 @@ class TORCH_API RpcAgent {
virtual void sync() = 0;
// start accepting requests
virtual void start() {}
virtual void start() = 0;
// Stop accepting requests and shut down the RPC framework as soon as possible
// by terminating all RPC threads.
virtual void shutdown() = 0;
// Set the default rpc agent.
static void setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent);


@ -58,12 +58,34 @@ def _require_initialized(func):
return wrapper
def wait_all_workers():
def _wait_all_workers():
r"""
Block until all local and remote RPC processes reach this method, and then
destroy local the RPC agent. Every RPC process must call this method before
exit. This should be used to terminate the RPC framework, and there is no
guarantee that the RPC framework will work after this method returns.
Block until all local and remote RPC processes reach this method and wait
for all outstanding work to complete. Every RPC process must call this
method before exit to perform a graceful shutdown. This should be used to
terminate the RPC framework, and there is no guarantee that the RPC
framework will work after this method returns.
"""
global _agent
if _agent:
_agent.join()
def shutdown(graceful=True):
r"""
Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
stops the local agent from accepting outstanding requests, and shuts
down the RPC framework by terminating all RPC threads. If graceful=True,
then this will block until all local and remote RPC processes reach this
method and wait for all outstanding work to complete. Otherwise, if
graceful=False, then this is a local shutdown, and it does not wait for
other RPC processes to reach this method.
Arguments:
graceful (bool): Whether to do a graceful shutdown or not. If True,
this will block until all local and remote RPC
processes have reached this method and wait for all
outstanding work to complete.
Example::
@ -73,27 +95,28 @@ def wait_all_workers():
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
"""
global _agent
if _agent:
_agent.join()
_agent = None
if graceful:
_wait_all_workers()
_destroy_rref_context(_ignore_rref_leak)
# clean up python rpc handler in wait_all_workers(), see comments in
_agent.shutdown()
# clean up python rpc handler in shutdown(), see comments in
# PythonRpcHandler::cleanup(), call it in python API because the
# cleanup() function has python dependency, it assumes python
# interpreter exists
_cleanup_python_rpc_handler()
_agent = None
# TODO: add a context manager to wrap _init_rpc_backend and wait_all_workers
# TODO: add a context manager to wrap _init_rpc_backend and shutdown
def _init_rpc_backend(
backend=backend_registry.BackendType.PROCESS_GROUP,
store=None,
@ -205,12 +228,12 @@ def remote(to, func, args=None, kwargs=None):
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
"""
qualified_name = torch.jit._find_builtin(func)
@ -274,12 +297,12 @@ def rpc_sync(to, func, args=None, kwargs=None):
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
"""
fut = _invoke_rpc(to, func, args, kwargs)
return fut.wait()
@ -314,12 +337,12 @@ def rpc_async(to, func, args=None, kwargs=None):
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.wait_all_workers()
>>> rpc.shutdown()
"""
fut = _invoke_rpc(to, func, args, kwargs)
return fut
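
For reference, a self-contained sketch (not part of this diff) of the graceful=False path described in the rpc.shutdown() docstring above, modeled on the new test_local_shutdown_with_rpc test. The master address/port values, worker names, and the _set_done/_peer_done handshake helpers are illustrative assumptions.

    import multiprocessing as mp
    import os
    import time

    import torch
    import torch.distributed.rpc as rpc

    # Peer-done flag flipped over RPC; mirrors the _set_rpc_done/_check_rpc_done
    # handshake in the new tests so that neither worker calls the local
    # (graceful=False) shutdown before the other has finished its RPCs.
    _peer_done = False

    def _set_done():
        global _peer_done
        _peer_done = True

    def run_worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        rpc.init_rpc(name="worker{}".format(rank), rank=rank, world_size=world_size)

        dst = "worker{}".format((rank + 1) % world_size)
        ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))
        assert torch.equal(ret, torch.ones(2, 2) * 2)

        # Tell the peer we are done, then wait until the peer reports the same.
        rpc.rpc_sync(dst, _set_done)
        while not _peer_done:
            time.sleep(0.01)

        # Local shutdown: abort the listener thread and join background threads
        # without waiting for the other worker to reach this point.
        rpc.shutdown(graceful=False)

    if __name__ == "__main__":
        world_size = 2
        procs = []
        for rank in range(world_size):
            p = mp.Process(target=run_worker, args=(rank, world_size))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()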