Add local shutdown to process group agent (#30330)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30330

This is now possible due to previous changes made in `gloo` and `ProcessGroupGloo`. We `abort` the listener thread that is waiting for a message, and join all other threads. The API is changed so that the previous `wait_all_workers` no longer destroys the agent; that is now done in a new `shutdown` method. All callsites are updated accordingly.

ghstack-source-id: 94673884

Test Plan: Unit tests pass.

Reviewed By: mrshenli

Differential Revision: D18661775

fbshipit-source-id: 5aaa7c14603e18253394224994f6cd43234301c2
Committed by: Facebook Github Bot
Parent: 7ac8efa689
Commit: 1350b99de4
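The callsite change this commit introduces can be sketched as a small two-worker script. This is an illustrative sketch rather than code from the commit: the worker names, tensor shapes, and the MASTER_ADDR/MASTER_PORT values are assumptions, while rpc.init_rpc, rpc.rpc_sync, and rpc.shutdown(graceful=...) are the calls exercised in the diff below.

    # Sketch only: shutdown() now owns waiting for peers (graceful=True) and
    # destroying the local agent; wait_all_workers() is no longer the public API.
    import os
    import multiprocessing as mp

    import torch
    import torch.distributed.rpc as rpc

    def run_worker(rank, world_size):
        rpc.init_rpc(name="worker{}".format(rank), rank=rank, world_size=world_size)
        if rank == 0:
            # Any RPC work happens between init_rpc() and shutdown().
            rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2, 2), 1))
        # Previously, rpc.wait_all_workers() blocked on all workers *and* destroyed
        # the local agent. Now shutdown() does both; graceful=False (added by this
        # commit) skips the wait and tears down only the local agent.
        rpc.shutdown()

    if __name__ == "__main__":
        os.environ["MASTER_ADDR"] = "localhost"  # assumed single-host run
        os.environ["MASTER_PORT"] = "29500"      # assumed free port
        world_size = 2
        procs = [mp.Process(target=run_worker, args=(i, world_size)) for i in range(world_size)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()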
@@ -314,13 +314,14 @@ The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows:
Simple end to end example
^^^^^^^^^^^^^^^^^^^^^^^^^

Putting it all together, a very simple end to end example using distributed
autograd and distributed optimizer is as follows:
Putting it all together, the following is a simple end to end example using
distributed autograd and the distributed optimizer. If the code is placed into a
file called "dist_autograd_simple.py", it can be run with the command
:code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`:

.. code::

    import multiprocessing as mp
    from tempfile import NamedTemporaryFile
    import torch
    import torch.distributed.autograd as dist_autograd
    from torch.distributed import rpc

@@ -330,50 +331,52 @@ autograd and distributed optimizer is as follows:
    def random_tensor():
        return torch.rand((3, 3), requires_grad=True)

    def _run_process(self_rank, dst_rank):
        self_name = "worker{}".format(self_rank)
    def _run_process(rank, dst_rank, world_size):
        name = "worker{}".format(rank)
        dst_name = "worker{}".format(dst_rank)

        # Initialize RPC.
        rpc.init_rpc(
            self_name=self_name,
            self_rank=self_rank,
            worker_name_to_id={"worker0": 0, "worker1": 1},
            name=name,
            rank=rank,
            world_size=world_size
        )

        # Use a distributed autograd context.
        with dist_autograd.context() as context_id:
            # Forward pass (create references on remote nodes).
            rref1 = rpc.remote(dst_name, random_tensor)
            rref2 = rpc.remote(dst_name, random_tensor)
            loss = rref1.to_here() + rref2.to_here()

            # Backward pass (run distributed autograd).
            dist_autograd.backward([loss.sum()])

            # Build DistributedOptimizer.
            dist_optim = DistributedOptimizer(
                optim.SGD,
                [rref1, rref2],
                lr=0.05,
            )

            # Run the distributed optimizer step.
            dist_optim.step()

    def run_process(self_rank, dst_rank):
        _run_process(self_rank, dst_rank)
        rpc.wait_all_workers()
    def run_process(rank, dst_rank, world_size):
        _run_process(rank, dst_rank, world_size)
        rpc.shutdown()

    processes = []

    # Run two workers.
    for i in range(2):
        p = mp.Process(target=run_process, args=(i, (i + 1) % 2))
    # Run world_size workers.
    world_size = 2
    for i in range(world_size):
        p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


.. _RFC: https://github.com/pytorch/pytorch/issues/23110
@@ -55,7 +55,7 @@ This library provides primitives allowing users to create and modify references
.. autofunction:: rpc_async
.. autofunction:: remote
.. autofunction:: get_worker_info
.. autofunction:: wait_all_workers
.. autofunction:: shutdown

Distributed Autograd Framework
------------------------------
@@ -1360,7 +1360,7 @@ class DistAutogradTest(RpcAgentTestFixture):
        # receive gradients from the node that received an error (and as a
        # result it didn't execute the rest of the graph).
        dist.barrier()
        rpc.wait_all_workers()
        rpc.shutdown()
        sys.exit(0)
@@ -127,7 +127,7 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
            # since we need to shutdown the RPC agent. If we don't shutdown the
            # RPC agent, tests would fail since RPC agent threads, locks and
            # condition variables are not properly terminated.
            rpc.wait_all_workers()
            rpc.shutdown()

        return return_value
test/rpc_test.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import concurrent.futures
from datetime import timedelta
import sys
import time
import unittest
from collections import namedtuple
from unittest import mock

@@ -18,6 +19,21 @@ from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
from rpc_agent_test_fixture import RpcAgentTestFixture

rpc_done = [False, False, False, False]

# TODO: dedupe this with the code in dist_autograd_test.py.
# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size
# we don't need a lock here since the GIL is held while executing remote
# python UDFs, so access is serialized across several workers.
def _set_rpc_done(rank_distance):
    global rpc_done
    rpc_done[rank_distance] = True

def _check_rpc_done(rank_distance):
    while not rpc_done[rank_distance]:
        # yield control to other threads
        time.sleep(0)

def requires_process_group_agent(message=""):
    def decorator(old_func):

@@ -127,7 +143,6 @@ def my_tensor_function(a, b):
    return a + b

def my_sleep_func(seconds=1):
    import time
    time.sleep(seconds)
@@ -346,7 +361,7 @@ class RpcTest(RpcAgentTestFixture):
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        rpc.wait_all_workers()
        rpc.shutdown()

    @dist_init(setup_rpc=False)
    def test_reinit(self):

@@ -379,8 +394,9 @@ class RpcTest(RpcAgentTestFixture):
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        rpc.wait_all_workers()
        rpc.shutdown()

    @unittest.skip("test_invalid_names is flaky, see https://github.com/pytorch/pytorch/issues/25912")
    @dist_init(setup_rpc=False)
    def test_invalid_names(self):
        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):

@@ -447,8 +463,8 @@ class RpcTest(RpcAgentTestFixture):

        from torch.distributed.rpc.api import _agent
        self.assertEqual(_agent, None)
        # wait_all_workers() should not do anything as _agent is None
        rpc.wait_all_workers()
        # shutdown() should not do anything as _agent is None
        rpc.shutdown()
        # We need this barrier here because although init_process_group is
        # blocking, it does not guarantee that all ranks are done with
        # initialization after the call. We did run into issues with it where

@@ -525,7 +541,7 @@ class RpcTest(RpcAgentTestFixture):
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init(setup_rpc=False)
    def test_wait_all_workers(self):
    def test_shutdown(self):
        # Initialize RPC.
        rpc.init_rpc(
            name="worker%d" % self.rank,

@@ -543,7 +559,7 @@ class RpcTest(RpcAgentTestFixture):
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.wait_all_workers()
        rpc.shutdown()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            rpc.rpc_sync(

@@ -552,8 +568,8 @@ class RpcTest(RpcAgentTestFixture):
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # it's safe to call wait_all_workers() multiple times
        rpc.wait_all_workers()
        # it's safe to call shutdown() multiple times
        rpc.shutdown()

    @dist_init
    def test_expected_src(self):

@@ -717,8 +733,6 @@ class RpcTest(RpcAgentTestFixture):
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        import time

        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []

@@ -1138,7 +1152,7 @@ class RpcTest(RpcAgentTestFixture):
            import torch.distributed.rpc.api as api
            api._ignore_rref_leak = True

        rpc.wait_all_workers()
        rpc.shutdown()

    @dist_init(setup_rpc=False)
    def test_rref_leak(self):
@@ -1149,6 +1163,68 @@ class RpcTest(RpcAgentTestFixture):
    def test_ignore_rref_leak(self):
        self._test_rref_leak(ignore_leak=True)

    @dist_init(setup_rpc=False)
    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
    def test_local_shutdown(self):
        # test that we can start RPC and then immediately locally shutdown
        # without sending any messages.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=rpc.backend_registry.BackendType[
                dist_utils.TEST_CONFIG.rpc_backend_name
            ],
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        # pass in graceful=False to ensure that we don't wait for other workers.
        rpc.shutdown(graceful=False)

    @dist_init(setup_rpc=False)
    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
    def test_local_shutdown_with_rpc(self):
        # test that we can start RPC, send RPCs, and then run local shutdown.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=rpc.backend_registry.BackendType[
                dist_utils.TEST_CONFIG.rpc_backend_name
            ],
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        # wait for RPCs to be done, so that some workers don't try to shut down
        # too early.
        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(1,))
        _check_rpc_done(1)
        # pass in graceful=False to ensure that we don't wait for other workers.
        rpc.shutdown(graceful=False)

    @dist_init(setup_rpc=False)
    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
    def test_wait_all_workers_and_shutdown(self):
        # This test ensures that both rpc._wait_all_workers() and rpc.shutdown() can be
        # called without errors being raised due to attempting to shut down
        # multiple times.
        rpc.init_rpc(
            name="worker%d" % self.rank,
            backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options
        )
        from torch.distributed.rpc.api import _wait_all_workers
        # intentional call to internal _wait_all_workers.
        _wait_all_workers()
        rpc.shutdown()

    @dist_init(setup_rpc=False)
    def test_get_rpc_timeout(self):
        timeout = timedelta(seconds=1)

@@ -1167,7 +1243,7 @@ class RpcTest(RpcAgentTestFixture):
        )
        set_timeout = rpc.get_rpc_timeout()
        self.assertEqual(timeout, set_timeout)
        rpc.wait_all_workers()
        rpc.shutdown()

    @dist_init
    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
@@ -114,9 +114,9 @@ Otherwise, throws an exception.
            return PyRRef::unpickle(t);
          }));

  // future.wait() should not be called after wait_all_workers(), e.g.,
  // pythonRpcHandler is cleaned up in wait_all_workers(), after
  // wait_all_workers(), python objects returned from rpc python call can not be
  // future.wait() should not be called after shutdown(), e.g.,
  // pythonRpcHandler is cleaned up in shutdown(), after
  // shutdown(), python objects returned from rpc python call can not be
  // resolved.
  auto futureMessage =
      shared_ptr_class_<FutureMessage>(module, "FutureMessage")

@@ -162,6 +162,10 @@ Otherwise, throws an exception.
          "join",
          &ProcessGroupAgent::join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "shutdown",
          &ProcessGroupAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "sync",
          &ProcessGroupAgent::sync,
@@ -127,7 +127,6 @@ ProcessGroupAgent::ProcessGroupAgent(
          WorkerInfo(std::move(workerName), pg->getRank()),
          c10::guts::make_unique<RequestCallbackImpl>(),
          rpcTimeout),
      shutdown_{false},
      pg_(std::move(pg)),
      sendCounts_(pg_->getSize()),
      recvCounts_(pg_->getSize()),

@@ -184,30 +183,12 @@ std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
}

void ProcessGroupAgent::join() {
  // Every process i sends a SHUTDOWN message to process i + 1. This is
  // necessary for now because:
  // 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
  //    feed it a message or kill the thread.
  // 2. A GLOO process cannot send message to itself. (there is an ongoing
  //    effort to fix this problem).
  shutdown_.store(true);
  sync();
  // This is needed in case no futures were created, otherwise the future
  // timeout watchdog would sleep forever.
  futureTimeoutCV_.notify_one();
  std::unique_lock<std::mutex> lock(futureMutex_);
  futureCV_.wait(
      lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
  lock.unlock();
  pg_->barrier()->wait();
  int dst = (pg_->getRank() + 1) % pg_->getSize();
  enqueueSend(
      SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
  threadPool_.waitWorkComplete();
  listenerThread_.join();
  futureTimeoutThread_.join();
  PythonRpcHandler::getInstance().cleanup();
}

bool ProcessGroupAgent::hasPendingMessage() {

@@ -273,14 +254,38 @@ void ProcessGroupAgent::sync() {
}

void ProcessGroupAgent::start() {
  {
    std::lock_guard<std::mutex> futureLock{futureMutex_};
    rpcRunning_.store(true);
  }
  listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
  futureTimeoutThread_ =
      std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
}

void ProcessGroupAgent::shutdown() {
  LOG(INFO) << "Shutting down ProcessGroupAgent.";
  std::unique_lock<std::mutex> lock{futureMutex_};
  if (!rpcRunning_.exchange(false)) {
    return;
  }
  lock.unlock();
  futureTimeoutCV_.notify_one();
  futureTimeoutThread_.join();
  {
    std::unique_lock<std::mutex> lock(recvWorkMutex_);
    if (recvWork_) {
      recvWork_->abort();
    }
  }
  threadPool_.waitWorkComplete();
  listenerThread_.join();
}

std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
    const WorkerInfo& to,
    Message&& message) {
  TORCH_CHECK(rpcRunning_.load(), "ProcessGroupAgent hasn't started.")
  TORCH_CHECK(
      to.id_ < (worker_id_t)pg_->getSize(),
      "Destination rank is out of bound, got ",

@@ -460,10 +465,19 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
}

void ProcessGroupAgent::listenLoop() {
  while (true) {
  while (rpcRunning_.load()) {
    // rank, tensor size, message type
    std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
    pg_->recvAnysource(preamble, pg_->getRank())->wait();
    auto work = pg_->recvAnysource(preamble, pg_->getRank());
    {
      std::lock_guard<std::mutex> guard(recvWorkMutex_);
      recvWork_ = work;
    }

    if (!rpcRunning_.load() || !work->wait() /* aborted */) {
      return;
    }

    int64_t* preamble_items = preamble.front().storage().data<int64_t>();

    auto srcRank = preamble_items[0];

@@ -487,9 +501,12 @@ void ProcessGroupAgent::listenLoop() {
}

void ProcessGroupAgent::pollTimedOutRPCs() {
  while (!shutdown_.load()) {
    std::chrono::milliseconds sleepTime;
  while (true) {
    std::unique_lock<std::mutex> lock{futureMutex_};
    if (!rpcRunning_.load()) {
      return;
    }
    std::chrono::milliseconds sleepTime;
    // Estimate amount of time the first future will time out in, and sleep
    // for that long.
    // if there are no futures or the first future's RPC timeout is set to 0

@@ -509,7 +526,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
      futureTimeoutCV_.wait_for(lock, sleepTime);
    }

    if (shutdown_.load()) {
    if (!rpcRunning_.load()) {
      return;
    }
@@ -59,6 +59,8 @@ class ProcessGroupAgent : public RpcAgent {

  void start() override;

  void shutdown() override;

 protected:
  // This method wraps the destination information and the message into a
  // SendWork object, and put the SendWork into a queue. Another thread will

@@ -145,10 +147,6 @@ class ProcessGroupAgent : public RpcAgent {
    return ++nextId_;
  }

  // atomic bool indicating if join() has been called and background threads
  // should shutdown.
  std::atomic_bool shutdown_;

  std::shared_ptr<c10d::ProcessGroup> pg_;
  // worker name -> rank
  std::unordered_map<std::string, int> nameMap_;

@@ -161,12 +159,23 @@ class ProcessGroupAgent : public RpcAgent {
  MessageCounter recvCounts_;

  std::atomic<int64_t> nextId_;
  // atomic bool indicating if this agent is running. It is set in
  // ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
  // ProcessGroupAgent::join. It controls whether several background threads
  // should be running.
  // We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
  // conditions when notifying condition variables.
  std::atomic<bool> rpcRunning_{false};
  // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
  // when using the same tag.
  std::vector<std::mutex> sendMutexes_;
  std::thread listenerThread_;
  // A thread to poll existing futures and check for timed out ones.
  std::thread futureTimeoutThread_;
  // Lock and shared ptr to currently pending work, set in listenloop() and
  // interruptible in shutdown().
  std::mutex recvWorkMutex_;
  std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
  // A threadPool that processing both SendWork and RecvWork. There are two
  // motivations for adding a ThreadPool:
  // (1) RPC serialization/deserialization and processing can be expensive,
@@ -127,7 +127,11 @@ class TORCH_API RpcAgent {
  virtual void sync() = 0;

  // start accepting requests
  virtual void start() {}
  virtual void start() = 0;

  // Stop accepting requests and shutdown the RPC framework as soon as possible
  // by terminating all RPC threads.
  virtual void shutdown() = 0;

  // Set the default rpc agent.
  static void setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent);
@@ -58,12 +58,34 @@ def _require_initialized(func):
        return wrapper


def wait_all_workers():
def _wait_all_workers():
    r"""
    Block until all local and remote RPC processes reach this method, and then
    destroy local the RPC agent. Every RPC process must call this method before
    exit. This should be used to terminate the RPC framework, and there is no
    guarantee that the RPC framework will work after this method returns.
    Block until all local and remote RPC processes reach this method and wait
    for all outstanding work to complete. Every RPC process must call this
    method before exit to perform a graceful shutdown. This should be used to
    terminate the RPC framework, and there is no guarantee that the RPC
    framework will work after this method returns.
    """
    global _agent

    if _agent:
        _agent.join()

def shutdown(graceful=True):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If graceful=True,
    then this will block until all local and remote RPC processes reach this
    method and wait for all outstanding work to complete. Otherwise, if
    graceful=False, then this is a local shutdown, and it does not wait for
    other RPC processes to reach this method.

    Arguments:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will block until all local and remote RPC
                         processes have reached this method and wait for all
                         outstanding work to complete.

    Example::

@@ -73,27 +95,28 @@ def wait_all_workers():
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()

        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()
    """
    global _agent

    if _agent:
        _agent.join()
        _agent = None
    if graceful:
        _wait_all_workers()
    _destroy_rref_context(_ignore_rref_leak)
    # clean up python rpc handler in wait_all_workers(), see comments in
    _agent.shutdown()
    # clean up python rpc handler in shutdown(), see comments in
    # PythonRpcHandler::cleanup(), call it in python API because the
    # cleanup() function has python dependency, it assumes python
    # interpreter exists
    _cleanup_python_rpc_handler()
    _agent = None

# TODO: add a context manager to wrap _init_rpc_backend and wait_all_workers
# TODO: add a context manager to wrap _init_rpc_backend and shutdown
def _init_rpc_backend(
    backend=backend_registry.BackendType.PROCESS_GROUP,
    store=None,

@@ -205,12 +228,12 @@ def remote(to, func, args=None, kwargs=None):
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()

        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()
    """
    qualified_name = torch.jit._find_builtin(func)

@@ -274,12 +297,12 @@ def rpc_sync(to, func, args=None, kwargs=None):
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()

        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()
    """
    fut = _invoke_rpc(to, func, args, kwargs)
    return fut.wait()

@@ -314,12 +337,12 @@ def rpc_async(to, func, args=None, kwargs=None):
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()

        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.wait_all_workers()
        >>> rpc.shutdown()
    """
    fut = _invoke_rpc(to, func, args, kwargs)
    return fut