Add local shutdown to process group agent (#30330)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30330

Local shutdown is now possible due to previous changes made in `gloo` and `ProcessGroupGloo`: we `abort` the listener thread that is waiting for a message, and join all other threads. The API is changed so that the previous `wait_all_workers` no longer destroys the agent; that is now done in a new `shutdown` method. All call sites are updated accordingly.

ghstack-source-id: 94673884

Test Plan: Unit tests pass.

Reviewed By: mrshenli

Differential Revision: D18661775

fbshipit-source-id: 5aaa7c14603e18253394224994f6cd43234301c2
Committed by: Facebook Github Bot
Parent: 7ac8efa689
Commit: 1350b99de4
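For orientation, the following is a minimal sketch of what the user-facing API looks like after this change, assembled from the docstrings and tests in the diff below. The worker names, tensor values, and the address/port settings are illustrative and not part of the patch.

    # Illustrative sketch only; see the docs and test changes below for the
    # authoritative usage of the new rpc.shutdown() API.
    import multiprocessing as mp
    import os

    import torch
    import torch.distributed.rpc as rpc


    def run_worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        rpc.init_rpc(name="worker{}".format(rank), rank=rank, world_size=world_size)

        if rank == 0:
            # Send one RPC so there is outstanding work to drain.
            rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2, 2), 1))

        # Graceful shutdown (the default): block until all workers reach this
        # point and all outstanding work completes, then destroy the local agent.
        # Calling rpc.shutdown(graceful=False) instead tears down only the local
        # agent without waiting for other workers.
        rpc.shutdown()


    if __name__ == "__main__":
        world_size = 2
        procs = [mp.Process(target=run_worker, args=(i, world_size)) for i in range(world_size)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()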

@@ -314,13 +314,14 @@ The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows:
 Simple end to end example
 ^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Putting it all together, a very simple end to end example using distributed
-autograd and distributed optimizer is as follows:
+Putting it all together, the following is a simple end to end example using
+distributed autograd and the distributed optimizer. If the code is placed into a
+file called "dist_autograd_simple.py", it can be run with the command
+:code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`:
 
 .. code::
 
     import multiprocessing as mp
-    from tempfile import NamedTemporaryFile
     import torch
     import torch.distributed.autograd as dist_autograd
     from torch.distributed import rpc

@@ -330,50 +331,52 @@ autograd and distributed optimizer is as follows:
     def random_tensor():
         return torch.rand((3, 3), requires_grad=True)
 
-    def _run_process(self_rank, dst_rank):
-        self_name = "worker{}".format(self_rank)
+    def _run_process(rank, dst_rank, world_size):
+        name = "worker{}".format(rank)
         dst_name = "worker{}".format(dst_rank)
 
         # Initialize RPC.
         rpc.init_rpc(
-            self_name=self_name,
-            self_rank=self_rank,
-            worker_name_to_id={"worker0": 0, "worker1": 1},
+            name=name,
+            rank=rank,
+            world_size=world_size
         )
 
         # Use a distributed autograd context.
         with dist_autograd.context() as context_id:
             # Forward pass (create references on remote nodes).
             rref1 = rpc.remote(dst_name, random_tensor)
             rref2 = rpc.remote(dst_name, random_tensor)
             loss = rref1.to_here() + rref2.to_here()
 
             # Backward pass (run distributed autograd).
             dist_autograd.backward([loss.sum()])
 
             # Build DistributedOptimizer.
             dist_optim = DistributedOptimizer(
                 optim.SGD,
                 [rref1, rref2],
                 lr=0.05,
             )
 
             # Run the distributed optimizer step.
             dist_optim.step()
 
-    def run_process(self_rank, dst_rank):
-        _run_process(self_rank, dst_rank)
-        rpc.wait_all_workers()
+    def run_process(rank, dst_rank, world_size):
+        _run_process(rank, dst_rank, world_size)
+        rpc.shutdown()
 
     processes = []
 
-    # Run two workers.
-    for i in range(2):
-        p = mp.Process(target=run_process, args=(i, (i + 1) % 2))
+    # Run world_size workers.
+    world_size = 2
+    for i in range(world_size):
+        p = mp.Process(target=run_process, args=(i, (i + 1) % 2, world_size))
         p.start()
         processes.append(p)
 
     for p in processes:
         p.join()
 
 
 .. _RFC: https://github.com/pytorch/pytorch/issues/23110

@@ -55,7 +55,7 @@ This library provides primitives allowing users to create and modify references
 .. autofunction:: rpc_async
 .. autofunction:: remote
 .. autofunction:: get_worker_info
-.. autofunction:: wait_all_workers
+.. autofunction:: shutdown
 
 Distributed Autograd Framework
 ------------------------------

@@ -1360,7 +1360,7 @@ class DistAutogradTest(RpcAgentTestFixture):
             # receive gradients from the node that received an error (and as a
             # result it didn't execute the rest of the graph).
             dist.barrier()
-            rpc.wait_all_workers()
+            rpc.shutdown()
             sys.exit(0)
 
 

@@ -127,7 +127,7 @@ def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True):
            # since we need to shutdown the RPC agent. If we don't shutdown the
            # RPC agent, tests would fail since RPC agent threads, locks and
            # condition variables are not properly terminated.
-            rpc.wait_all_workers()
+            rpc.shutdown()
 
        return return_value
 

test/rpc_test.py (102 changed lines)

@@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import concurrent.futures
 from datetime import timedelta
 import sys
+import time
 import unittest
 from collections import namedtuple
 from unittest import mock

@@ -18,6 +19,21 @@ from torch.distributed.rpc.api import _use_rpc_pickler
 from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
 from rpc_agent_test_fixture import RpcAgentTestFixture
 
+rpc_done = [False, False, False, False]
+
+# TODO: dedupe this with the code in dist_autograd_test.py.
+# Send rpc done info and context_id to
+# dst_rank = (self.rank + rank_distance) % self.world_size
+# we don't need a lock here since the GIL is held while executing remote
+# python UDFs, so access is serialized across several workers.
+def _set_rpc_done(rank_distance):
+    global rpc_done
+    rpc_done[rank_distance] = True
+
+def _check_rpc_done(rank_distance):
+    while not rpc_done[rank_distance]:
+        # yield control to other threads
+        time.sleep(0)
 
 def requires_process_group_agent(message=""):
     def decorator(old_func):

@@ -127,7 +143,6 @@ def my_tensor_function(a, b):
     return a + b
 
 def my_sleep_func(seconds=1):
-    import time
     time.sleep(seconds)
 
 

@@ -346,7 +361,7 @@ class RpcTest(RpcAgentTestFixture):
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init(setup_rpc=False)
     def test_reinit(self):

@@ -379,8 +394,9 @@ class RpcTest(RpcAgentTestFixture):
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
+    @unittest.skip("test_invalid_names is flaky, see https://github.com/pytorch/pytorch/issues/25912")
     @dist_init(setup_rpc=False)
     def test_invalid_names(self):
         with self.assertRaisesRegex(RuntimeError, "Worker name must match"):

@@ -447,8 +463,8 @@ class RpcTest(RpcAgentTestFixture):
 
         from torch.distributed.rpc.api import _agent
         self.assertEqual(_agent, None)
-        # wait_all_workers() should not do anything as _agent is None
-        rpc.wait_all_workers()
+        # shutdown() should not do anything as _agent is None
+        rpc.shutdown()
         # We need this barrier here because although init_process_group is
         # blocking, it does not guarantee that all ranks are done with
         # initialization after the call. We did run into issues with it where

@@ -525,7 +541,7 @@ class RpcTest(RpcAgentTestFixture):
         self.assertEqual(ret, torch.ones(n, n) * 2)
 
     @dist_init(setup_rpc=False)
-    def test_wait_all_workers(self):
+    def test_shutdown(self):
         # Initialize RPC.
         rpc.init_rpc(
             name="worker%d" % self.rank,

@@ -543,7 +559,7 @@ class RpcTest(RpcAgentTestFixture):
             args=(torch.ones(n, n), torch.ones(n, n)),
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
         with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
             rpc.rpc_sync(

@@ -552,8 +568,8 @@ class RpcTest(RpcAgentTestFixture):
                 args=(torch.ones(n, n), torch.ones(n, n)),
             )
 
-        # it's safe to call wait_all_workers() multiple times
-        rpc.wait_all_workers()
+        # it's safe to call shutdown() multiple times
+        rpc.shutdown()
 
     @dist_init
     def test_expected_src(self):

@@ -717,8 +733,6 @@ class RpcTest(RpcAgentTestFixture):
         self.assertEqual(ret, torch.ones(2, 2) + 1)
 
     def _stress_test_rpc(self, f, repeat=1000, args=()):
-        import time
-
         n = self.rank + 1
         dst_rank = n % self.world_size
         futs = []

@@ -1138,7 +1152,7 @@ class RpcTest(RpcAgentTestFixture):
             import torch.distributed.rpc.api as api
             api._ignore_rref_leak = True
 
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init(setup_rpc=False)
     def test_rref_leak(self):

@@ -1149,6 +1163,68 @@ class RpcTest(RpcAgentTestFixture):
     def test_ignore_rref_leak(self):
         self._test_rref_leak(ignore_leak=True)
 
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_local_shutdown(self):
+        # test that we can start RPC and then immediately locally shutdown
+        # without sending any messages.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[
+                dist_utils.TEST_CONFIG.rpc_backend_name
+            ],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_local_shutdown_with_rpc(self):
+        # test that we can start RPC, send RPCs, and then run local shutdown.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[
+                dist_utils.TEST_CONFIG.rpc_backend_name
+            ],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options,
+        )
+        n = self.rank + 1
+        dst_rank = n % self.world_size
+        ret = rpc.rpc_sync(
+            "worker{}".format(dst_rank),
+            torch.add,
+            args=(torch.ones(n, n), torch.ones(n, n)),
+        )
+        # wait for RPCs to be done, so that some workers don't try to shut down
+        # too early.
+        rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(1,))
+        _check_rpc_done(1)
+        # pass in graceful=False to ensure that we don't wait for other workers.
+        rpc.shutdown(graceful=False)
+
+    @dist_init(setup_rpc=False)
+    @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
+    def test_wait_all_workers_and_shutdown(self):
+        # This tests ensures that both rpc._wait_all_workers() and rpc.shutdown() can be
+        # called without errors being raised due to attempting to shut down
+        # multiple times.
+        rpc.init_rpc(
+            name="worker%d" % self.rank,
+            backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name],
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=self.rpc_backend_options
+        )
+        from torch.distributed.rpc.api import _wait_all_workers
+        # intentional call to internal _wait_all_workers.
+        _wait_all_workers()
+        rpc.shutdown()
+
     @dist_init(setup_rpc=False)
     def test_get_rpc_timeout(self):
         timeout = timedelta(seconds=1)

@@ -1167,7 +1243,7 @@ class RpcTest(RpcAgentTestFixture):
         )
         set_timeout = rpc.get_rpc_timeout()
         self.assertEqual(timeout, set_timeout)
-        rpc.wait_all_workers()
+        rpc.shutdown()
 
     @dist_init
     @requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")

@@ -114,9 +114,9 @@ Otherwise, throws an exception.
             return PyRRef::unpickle(t);
           }));
 
-  // future.wait() should not be called after wait_all_workers(), e.g.,
-  // pythonRpcHandler is cleaned up in wait_all_workers(), after
-  // wait_all_workers(), python objects returned from rpc python call can not be
+  // future.wait() should not be called after shutdown(), e.g.,
+  // pythonRpcHandler is cleaned up in shutdown(), after
+  // shutdown(), python objects returned from rpc python call can not be
   // resolved.
   auto futureMessage =
       shared_ptr_class_<FutureMessage>(module, "FutureMessage")

@@ -162,6 +162,10 @@ Otherwise, throws an exception.
              "join",
              &ProcessGroupAgent::join,
              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "shutdown",
+              &ProcessGroupAgent::shutdown,
+              py::call_guard<py::gil_scoped_release>())
          .def(
              "sync",
              &ProcessGroupAgent::sync,

@@ -127,7 +127,6 @@ ProcessGroupAgent::ProcessGroupAgent(
          WorkerInfo(std::move(workerName), pg->getRank()),
          c10::guts::make_unique<RequestCallbackImpl>(),
          rpcTimeout),
-      shutdown_{false},
      pg_(std::move(pg)),
      sendCounts_(pg_->getSize()),
      recvCounts_(pg_->getSize()),

@@ -184,30 +183,12 @@ std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
 }
 
 void ProcessGroupAgent::join() {
-  // Every process i sends a SHUTDOWN message to process i + 1. This is
-  // necessary for now because:
-  // 1. There is no abort API for ProcessGroup::recvAnysource yet. We have to
-  //    feed it a message or kill the thread.
-  // 2. A GLOO process cannot send message to itself. (there is an ongoing
-  //    effort to fix this problem).
-  shutdown_.store(true);
   sync();
-  // This is needed in case no futures were created, otherwise the future
-  // timeout watchdog would sleep forever.
-
-  futureTimeoutCV_.notify_one();
   std::unique_lock<std::mutex> lock(futureMutex_);
   futureCV_.wait(
       lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
   lock.unlock();
   pg_->barrier()->wait();
-  int dst = (pg_->getRank() + 1) % pg_->getSize();
-  enqueueSend(
-      SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
-  threadPool_.waitWorkComplete();
-  listenerThread_.join();
-  futureTimeoutThread_.join();
-  PythonRpcHandler::getInstance().cleanup();
 }
 
 bool ProcessGroupAgent::hasPendingMessage() {

@@ -273,14 +254,38 @@ void ProcessGroupAgent::sync() {
 }
 
 void ProcessGroupAgent::start() {
+  {
+    std::lock_guard<std::mutex> futureLock{futureMutex_};
+    rpcRunning_.store(true);
+  }
   listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
   futureTimeoutThread_ =
       std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
 }
 
+void ProcessGroupAgent::shutdown() {
+  LOG(INFO) << "Shutting down ProcessGroupAgent.";
+  std::unique_lock<std::mutex> lock{futureMutex_};
+  if (!rpcRunning_.exchange(false)) {
+    return;
+  }
+  lock.unlock();
+  futureTimeoutCV_.notify_one();
+  futureTimeoutThread_.join();
+  {
+    std::unique_lock<std::mutex> lock(recvWorkMutex_);
+    if (recvWork_) {
+      recvWork_->abort();
+    }
+  }
+  threadPool_.waitWorkComplete();
+  listenerThread_.join();
+}
+
 std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
     const WorkerInfo& to,
     Message&& message) {
+  TORCH_CHECK(rpcRunning_.load(), "ProcessGroupAgent hasn't started.")
   TORCH_CHECK(
       to.id_ < (worker_id_t)pg_->getSize(),
       "Destination rank is out of bound, got ",

@@ -460,10 +465,19 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
 }
 
 void ProcessGroupAgent::listenLoop() {
-  while (true) {
+  while (rpcRunning_.load()) {
     // rank, tensor size, message type
     std::vector<torch::Tensor> preamble = {torch::empty({3}, {torch::kInt64})};
-    pg_->recvAnysource(preamble, pg_->getRank())->wait();
+    auto work = pg_->recvAnysource(preamble, pg_->getRank());
+    {
+      std::lock_guard<std::mutex> guard(recvWorkMutex_);
+      recvWork_ = work;
+    }
+
+    if (!rpcRunning_.load() || !work->wait() /* aborted */) {
+      return;
+    }
+
     int64_t* preamble_items = preamble.front().storage().data<int64_t>();
 
     auto srcRank = preamble_items[0];

@@ -487,9 +501,12 @@ void ProcessGroupAgent::listenLoop() {
 }
 
 void ProcessGroupAgent::pollTimedOutRPCs() {
-  while (!shutdown_.load()) {
-    std::chrono::milliseconds sleepTime;
+  while (true) {
     std::unique_lock<std::mutex> lock{futureMutex_};
+    if (!rpcRunning_.load()) {
+      return;
+    }
+    std::chrono::milliseconds sleepTime;
     // Estimate amount of time the first future will time out in, and sleep
     // for that long.
     // if there are no futures or the first future's RPC timeout is set to 0

@@ -509,7 +526,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
       futureTimeoutCV_.wait_for(lock, sleepTime);
     }
 
-    if (shutdown_.load()) {
+    if (!rpcRunning_.load()) {
       return;
     }
 

@@ -59,6 +59,8 @@ class ProcessGroupAgent : public RpcAgent {
 
   void start() override;
 
+  void shutdown() override;
+
  protected:
   // This method wraps the destination information and the message into a
   // SendWork object, and put the SendWork into a queue. Another thread will

@@ -145,10 +147,6 @@ class ProcessGroupAgent : public RpcAgent {
     return ++nextId_;
   }
 
-  // atomic bool indicating if join() has been called and background threads
-  // should shutdown.
-  std::atomic_bool shutdown_;
-
   std::shared_ptr<c10d::ProcessGroup> pg_;
   // worker name -> rank
   std::unordered_map<std::string, int> nameMap_;

@@ -161,12 +159,23 @@ class ProcessGroupAgent : public RpcAgent {
   MessageCounter recvCounts_;
 
   std::atomic<int64_t> nextId_;
+  // atomic bool indicating if this agent is running. It is set in
+  // ProcessGroupAgent::start and unset in ProcessGroupAgent::shutdown and
+  // ProcessGroupAgent::join. It controls whether several background threads
+  // should be running.
+  // We lock access to this in shutdown() and pollTimedOutRPCs() to prevent race
+  // conditions when notifying condition variables.
+  std::atomic<bool> rpcRunning_{false};
   // one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
   // when using the same tag.
   std::vector<std::mutex> sendMutexes_;
   std::thread listenerThread_;
   // A thread to poll existing futures and check for timed out ones.
   std::thread futureTimeoutThread_;
+  // Lock and shared ptr to currently pending work, set in listenloop() and
+  // interruptible in shutdown().
+  std::mutex recvWorkMutex_;
+  std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
   // A threadPool that processing both SendWork and RecvWork. There are two
   // motivations for adding a ThreadPool:
   // (1) RPC serialization/deserialization and processing can be expensive,

@@ -127,7 +127,11 @@ class TORCH_API RpcAgent {
   virtual void sync() = 0;
 
   // start accepting requests
-  virtual void start() {}
+  virtual void start() = 0;
+
+  // Stop accepting requests and shutdown the RPC framework as soon as possible
+  // by terminating all RPC threads.
+  virtual void shutdown() = 0;
 
   // Set the default rpc agent.
   static void setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent);
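The ProcessGroupAgent changes above implement the local shutdown by flipping a running flag, aborting the listener's pending recvAnysource work, and joining the background threads. The toy Python sketch below mirrors only that control flow for illustration; the class and its members are hypothetical, and it stands in for recvWork_->abort() with a sentinel message, since queue.Queue has no abort.

    # Toy illustration (not PyTorch code) of the local-shutdown pattern above:
    # mark the agent as stopped, unblock the listener's pending receive, and
    # join the background thread.
    import queue
    import threading
    import time


    class ToyAgent:
        def __init__(self):
            self._running = False
            self._inbox = queue.Queue()
            self._listener = None

        def start(self):
            self._running = True
            self._listener = threading.Thread(target=self._listen_loop)
            self._listener.start()

        def _listen_loop(self):
            while self._running:
                msg = self._inbox.get()  # blocking receive, like recvAnysource
                if msg is None or not self._running:
                    return  # the receive was "aborted" during shutdown
                print("handling", msg)

        def shutdown(self):
            # Safe to call more than once: only the first call does the teardown.
            if not self._running:
                return
            self._running = False
            self._inbox.put(None)  # stand-in for aborting the pending receive
            self._listener.join()  # join the listener thread


    agent = ToyAgent()
    agent.start()
    agent._inbox.put("hello")
    time.sleep(0.1)
    agent.shutdown()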

@@ -58,12 +58,34 @@ def _require_initialized(func):
     return wrapper
 
 
-def wait_all_workers():
+def _wait_all_workers():
     r"""
-    Block until all local and remote RPC processes reach this method, and then
-    destroy local the RPC agent. Every RPC process must call this method before
-    exit. This should be used to terminate the RPC framework, and there is no
-    guarantee that the RPC framework will work after this method returns.
+    Block until all local and remote RPC processes reach this method and wait
+    for all outstanding work to complete. Every RPC process must call this
+    method before exit to perform a graceful shutdown. This should be used to
+    terminate the RPC framework, and there is no guarantee that the RPC
+    framework will work after this method returns.
+    """
+    global _agent
+
+    if _agent:
+        _agent.join()
+
+def shutdown(graceful=True):
+    r"""
+    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
+    stops the local agent from accepting outstanding requests, and shuts
+    down the RPC framework by terminating all RPC threads. If graceful=True,
+    then this will block until all local and remote RPC processes reach this
+    method and wait for all outstanding work to complete. Otherwise, if
+    graceful=False, then this is a local shutdown, and it does not wait for
+    other RPC processes to reach this method.
+
+    Arguments:
+        graceful (bool): Whether to do a graceful shutdown or not. If True,
+                         this will block until all local and remote RPC
+                         processes have reached this method and wait for all
+                         outstanding work to complete.
 
     Example::
 

@@ -73,27 +95,28 @@ def wait_all_workers():
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     global _agent
 
     if _agent:
-        _agent.join()
-        _agent = None
+        if graceful:
+            _wait_all_workers()
         _destroy_rref_context(_ignore_rref_leak)
-        # clean up python rpc handler in wait_all_workers(), see comments in
+        _agent.shutdown()
+        # clean up python rpc handler in shutdown(), see comments in
         # PythonRpcHandler::cleanup(), call it in python API because the
         # cleanup() function has python dependency, it assumes python
        # interpreter exists
        _cleanup_python_rpc_handler()
+        _agent = None
 
-# TODO: add a context manager to wrap _init_rpc_backend and wait_all_workers
+# TODO: add a context manager to wrap _init_rpc_backend and shutdown
 def _init_rpc_backend(
     backend=backend_registry.BackendType.PROCESS_GROUP,
     store=None,

@@ -205,12 +228,12 @@ def remote(to, func, args=None, kwargs=None):
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     qualified_name = torch.jit._find_builtin(func)
 

@@ -274,12 +297,12 @@ def rpc_sync(to, func, args=None, kwargs=None):
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut.wait()

@@ -314,12 +337,12 @@ def rpc_async(to, func, args=None, kwargs=None):
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
 
        On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
-        >>> rpc.wait_all_workers()
+        >>> rpc.shutdown()
     """
     fut = _invoke_rpc(to, func, args, kwargs)
     return fut