Summary:

Features:
* sync and async RPC for builtin operators (see the usage sketch below)
* RpcAgent API
* ProcessGroupAgent implementation

Goals:
* have a minimum working and testable RPC implementation
* make sure the RpcAgent API is sufficient for the future ThriftAgent and TensorPipeAgent implementations
  * For the TensorPipe implementation, the agent might allocate multiple underlying communication channels of different types, and might also use streaming serialization/deserialization for large tensors. To support this, the current implementation only converts a BuiltinOp into a Message containing a byte vector and a tensor table; it is up to the RpcAgent implementation to determine how it serializes a Message object.
  * For the ThriftAgent, Thrift has its own request/response matching solution, so Message.id is no longer necessary and can be dropped during serialization. All the agent needs to do is pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback (see the queue-based sketch below)
  * blocking means the callback does not return before sending out the response
  * non-blocking can be achieved by enqueuing the `(from, request, RpcAgent&)` tuple and using a different thread to process it; that is why there is an `RpcAgent&` arg in the parameter list

We are not exporting this diff until we finalize the distributed autograd design and publish the API review publicly: https://fb.quip.com/FabTAZKVgQpf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23228
ghstack-source-id: 87816717
Reviewed By: zhaojuanmao
Differential Revision: D15194693
fbshipit-source-id: 7adb600796613cde6073db6c227451b89940ecaf
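For orientation, here is a minimal usage sketch of the Python-facing API introduced by this diff. It shows only what the test file below exercises (a synchronous call, an asynchronous call returning a future, sync_rpc, and join_rpc); this is the in-progress API, not a finalized public interface, and it assumes a 'gloo' process group has already been initialized on every worker, exactly as the _wrap_with_rpc decorator in the tests does.

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run on every worker; each
# worker registers itself under its own name (rank 0 shown here).
dist.init_rpc('worker0')

# Synchronous RPC on a builtin operator: blocks until the result comes back.
ret = dist.rpc('worker1', torch.add, args=(torch.ones(2, 2), torch.ones(2, 2)))

# Asynchronous RPC: returns a future immediately; wait() yields the result.
fut = dist.rpc('worker1', torch.add, args=(torch.ones(2, 2), 1), async_call=True)
result = fut.wait()

dist.sync_rpc()  # drain outstanding RPC work, as test_sync_rpc does between calls
dist.join_rpc()  # shut down RPC; further dist.rpc calls raise "RPC has not been initialized"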
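The RequestCallback, RpcAgent, and Message types mentioned above live on the C++ side. The following is only a Python-level sketch of the non-blocking callback idea described in the summary (enqueue the `(from, request, RpcAgent&)` tuple and let a different thread send the response); the names Message, NonBlockingCallback, handler, and send_response are hypothetical and not part of the actual API.

import queue
import threading


class Message(object):
    """Byte payload plus tensor table, as described in the summary above."""
    def __init__(self, payload, tensors, msg_id=None):
        self.payload = payload   # byte vector; serialization format is left to the RpcAgent
        self.tensors = tensors   # tensor table, kept out of the byte payload
        self.id = msg_id         # request/response matching id; an agent such as Thrift may drop it


class NonBlockingCallback(object):
    """Enqueues (from, request, agent) tuples and responds from a worker thread."""
    def __init__(self, agent, handler):
        self.agent = agent        # the agent that will send the response back
        self.handler = handler    # maps a request Message to a response Message
        self.pending = queue.Queue()
        threading.Thread(target=self._process_loop, daemon=True).start()

    def __call__(self, from_worker, request):
        # A blocking callback would run self.handler here and not return until the
        # response had been sent; the non-blocking variant only enqueues the work.
        self.pending.put((from_worker, request, self.agent))

    def _process_loop(self):
        while True:
            from_worker, request, agent = self.pending.get()
            response = self.handler(request)
            agent.send_response(from_worker, response)  # hypothetical send-back method

A blocking RequestCallback, by contrast, would run the handler and send the response before returning to the agent. The test file below exercises the resulting Python API end to end.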
import sys

import torch
import torch.distributed as dist

from common_distributed import MultiProcessTestCase
from common_utils import load_tests, run_tests

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

if not dist.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)


def _wrap_with_rpc(func):
    # Initializes the default (gloo) process group and the RPC layer before
    # running the test body, then shuts RPC down afterwards via join_rpc().
    def wrapper(self):
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend='gloo', rank=self.rank,
                                world_size=self.world_size, store=store)
        dist.init_rpc('worker%d' % self.rank)
        func(self)
        dist.join_rpc()

    return wrapper


class RpcTest(MultiProcessTestCase):

    @property
    def world_size(self):
        return 4

    @_wrap_with_rpc
    def test_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_scalar_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), n))
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @_wrap_with_rpc
    def test_async_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        fut = dist.rpc('worker%d' % dstRank,
                       torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)),
                       async_call=True)
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_nonzero(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = dist.rpc('worker%d' % dstRank, torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @_wrap_with_rpc
    def test_multi_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = dist.rpc('worker%d' % dstRank, torch.add,
                           args=(torch.ones(n, n), torch.ones(n, n)))
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_sync_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            dist.sync_rpc()
            n = i + self.rank + 1
            ret1 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), torch.ones(n, n)))
            dist.sync_rpc()
            ret2 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), 2))
            dist.sync_rpc()
            self.assertEqual(ret1, torch.ones(n, n) * 2)
            self.assertEqual(ret2, torch.ones(n, n) * 3)

    @_wrap_with_rpc
    def test_join_rpc(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)
        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            dist.rpc('worker%d' % dstRank, torch.add,
                     args=(torch.ones(n, n), torch.ones(n, n)))

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()


if __name__ == '__main__':
    run_tests()