pytorch/test/test_rpc.py
Shen Li 8b349073ce sync and async torch.distributed.rpc for builtin operators (#23228)
Summary:
Features:

* sync and async RPC for builtin operators
* RpcAgent API
* ProcessGroupAgent implementation

Goal:

* have a minimum working and testable RPC implementation
* make sure the RpcAgent API is sufficient for future ThriftAgent and TensorPipeAgent implementations
  * A TensorPipe implementation might allocate multiple underlying communication channels of different types, and might also use streaming serialization/deserialization for large tensors. To support this, the current implementation only converts a BuiltinOp into a Message that contains a byte vector and a tensor table; it is up to the RpcAgent implementation to decide how to serialize a Message object (see the Message sketch after this list).
  * For ThriftAgent, since Thrift has its own request/response matching, the Message.id is not necessary and can be dropped during serialization. All the agent needs to do is pass the response Message object to the Future returned by send(...).
* support blocking and non-blocking RequestCallback
  * a blocking callback does not return until the response has been sent out
  * a non-blocking callback can enqueue the `(from, request, RpcAgent&)` tuple and process it on a different thread; that is why there is an `RpcAgent&` arg in the param list (see the callback sketch after this list)
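
To make the Message design above concrete, here is a minimal Python sketch of such a container; the field names below are illustrative only, not the actual C++ class added in this diff:

```python
# Illustrative sketch -- the real Message is a C++ class in this diff.
# It captures the idea above: an opaque byte payload plus a tensor table,
# leaving the wire format entirely to each RpcAgent implementation.
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class Message:
    payload: bytes = b""   # e.g. pickled op name and non-tensor args
    tensors: List[torch.Tensor] = field(default_factory=list)  # tensor table
    id: int = -1           # used for request/response matching; agents with
                           # their own matching (e.g. Thrift) may drop it
                           # during serialization
```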
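
Similarly, here is a minimal sketch of the non-blocking RequestCallback pattern from the last bullet, assuming a hypothetical agent object with a `send(to, message)` method and a hypothetical `run_builtin_op` helper (the real callback and agent live in C++):

```python
import queue
import threading


def run_builtin_op(request):
    # Hypothetical helper: deserialize `request`, run the builtin op, and
    # wrap the result into a response message.
    raise NotImplementedError


_pending = queue.Queue()


def non_blocking_callback(from_worker, request, agent):
    # The callback only enqueues the (from, request, agent) tuple and
    # returns immediately; the response is produced asynchronously.
    _pending.put((from_worker, request, agent))


def _process_loop():
    # A worker thread drains the queue and replies through the agent
    # reference, which is why the callback receives an RpcAgent& argument.
    while True:
        from_worker, request, agent = _pending.get()
        response = run_builtin_op(request)
        agent.send(from_worker, response)


threading.Thread(target=_process_loop, daemon=True).start()
```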

We are not exporting this diff until we finalize the distributed autograd design and publish the API review publicly.

https://fb.quip.com/FabTAZKVgQpf

Pull Request resolved: https://github.com/pytorch/pytorch/pull/23228
ghstack-source-id: 87816717

Reviewed By: zhaojuanmao

Differential Revision: D15194693

fbshipit-source-id: 7adb600796613cde6073db6c227451b89940ecaf
2019-08-06 16:03:01 -07:00

import sys
import torch
import torch.distributed as dist
from common_distributed import MultiProcessTestCase
from common_utils import load_tests, run_tests
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
if not dist.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)

def _wrap_with_rpc(func):
    # Sets up the default (gloo) process group and RPC before running the
    # test body, and tears RPC down afterwards via join_rpc().
    def wrapper(self):
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend='gloo', rank=self.rank,
                                world_size=self.world_size, store=store)
        dist.init_rpc('worker%d' % self.rank)
        func(self)
        dist.join_rpc()

    return wrapper

class RpcTest(MultiProcessTestCase):

    @property
    def world_size(self):
        return 4

    @_wrap_with_rpc
    def test_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_scalar_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), n))
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @_wrap_with_rpc
    def test_async_add(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        fut = dist.rpc('worker%d' % dstRank,
                       torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)),
                       async_call=True)
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_nonzero(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = dist.rpc('worker%d' % dstRank, torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @_wrap_with_rpc
    def test_multi_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = dist.rpc('worker%d' % dstRank, torch.add,
                           args=(torch.ones(n, n), torch.ones(n, n)))
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @_wrap_with_rpc
    def test_sync_rpc(self):
        dstRank = (self.rank + 1) % self.world_size
        for i in range(20):
            dist.sync_rpc()
            n = i + self.rank + 1
            ret1 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), torch.ones(n, n)))
            dist.sync_rpc()
            ret2 = dist.rpc('worker%d' % dstRank, torch.add,
                            args=(torch.ones(n, n), 2))
            dist.sync_rpc()
            self.assertEqual(ret1, torch.ones(n, n) * 2)
            self.assertEqual(ret2, torch.ones(n, n) * 3)

    @_wrap_with_rpc
    def test_join_rpc(self):
        n = self.rank + 1
        dstRank = n % self.world_size
        ret = dist.rpc('worker%d' % dstRank, torch.add,
                       args=(torch.ones(n, n), torch.ones(n, n)))
        self.assertEqual(ret, torch.ones(n, n) * 2)

        dist.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            dist.rpc('worker%d' % dstRank, torch.add,
                     args=(torch.ones(n, n), torch.ones(n, n)))

        # it's safe to call join_rpc() multiple times
        dist.join_rpc()


if __name__ == '__main__':
    run_tests()