pytorch/test/rpc_test.py
Pritam Damania 3bccd3fc0d Distributed Autograd - FAST mode backward pass implementation. (#27022)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022

This change implements the "FAST" mode distributed autograd backward
pass as described in https://github.com/pytorch/pytorch/issues/23110.

At a high level, the backward pass works as follows (a short usage sketch follows this list):
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure that
the local autograd engine executes the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in step 4) for the
respective RPCs to finish.
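
As a usage sketch of this flow (hedged: `dist_autograd.context()` and the exact
location/signature of `backward()` are assumed here; only
`get_gradients(context_id)` is named explicitly in this change, and the worker
name is illustrative):

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    # Assumes RPC has already been initialized on a group of workers.
    with dist_autograd.context() as context_id:
        t = torch.ones(2, 2, requires_grad=True)
        # The forward pass records a 'send' function on this node and a
        # matching 'recv' function on the callee.
        loss = rpc.rpc_sync("worker1", torch.add, args=(t, t)).sum()
        # FAST-mode backward: dependencies are computed from the roots and
        # all 'send' functions in the current context (steps 1-2), then each
        # node's local engine executes its subgraph (steps 3-6).
        dist_autograd.backward([loss])
        # Gradients are accumulated per-context (step 7), not in t.grad.
        grads = dist_autograd.get_gradients(context_id)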

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API, which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this, a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If the message received by `Future.wait()` has type EXCEPTION, we throw an
appropriate exception instead of just returning the message. This is in line
with what most Future.wait() APIs do (see the sketch after this list).
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
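
For item 3 above, a minimal sketch of the new wait() behavior (mirroring
`test_py_raise_in_user_func` in the test file below; the worker name is
illustrative):

    import torch.distributed.rpc as rpc

    def raise_func():
        raise ValueError("Expected error")

    # On an initialized RPC worker:
    fut = rpc.rpc_async("worker1", raise_func)
    try:
        fut.wait()  # re-raises the remote failure locally
    except Exception as e:
        print(e)  # message includes the original ValueError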
ghstack-source-id: 91794926

Test Plan: unit tests.

Differential Revision: D17652615

fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
2019-10-12 09:47:49 -07:00


from __future__ import absolute_import, division, print_function, unicode_literals
import concurrent.futures
import sys
import unittest
from collections import namedtuple
from unittest import mock
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from common_utils import load_tests
from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init
from torch.distributed import ProcessGroupAgent
from torch.distributed.rpc import RpcBackend
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler
def requires_process_group_agent(func):
from torch.distributed.rpc.api import _agent
    return unittest.skipUnless(
        isinstance(_agent, ProcessGroupAgent),
        "Only ProcessGroupAgent supports global termination detection",
    )(func)
VALUE_FUTURE = concurrent.futures.Future()
def stub_init_rpc_backend_handler(self_rank, self_name, init_method):
return mock.Mock() # RpcAgent.
def set_value(value):
VALUE_FUTURE.set_result(value)
# The following classes and functions are used to test Python user-defined
# functions, classes, and methods over RPC.
TensorClass = namedtuple("TensorClass", ["tensors"])
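# MyPickleClass's __getstate__/__setstate__ run the internal RPC pickler from
# within a user-defined pickle, to test nested serialization over RPC.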
class MyPickleClass:
def __init__(self):
self.t = None
def __getstate__(self):
(pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
)
return (pickled_python_udf, tensors)
def __setstate__(self, obj):
python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
result = python_udf.func(python_udf.args[0], python_udf.args[1])
self.t = result
def set(self, val):
self.t = val
class MyClass:
def __init__(self, a):
self.a = a
def my_instance_method(self, b):
return self.a + b
@classmethod
def my_class_method(cls, d, e):
return d + e
@staticmethod
def my_static_method(f):
return f > 10
def run_nested_pickle(pickle_cls_instance, tensor):
return pickle_cls_instance.t + tensor
def build_complex_tensors():
a = torch.ones(3, 3)
b = [a, a]
c = [b, b]
d = [a, b]
e = {a: d}
return [a, b, c, d, e]
def my_function(a, b, c):
return a + b + c
def my_tensor_function(a, b):
return a + b
def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
res = list_input[0]
for t in list_input:
res += t
for k, v in dict_input.items():
res += v
complex_tensors = tensor_class_input.tensors
return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
def my_rref_function(rref_a, rref_b):
return rref_a.to_here() + rref_b.to_here()
def no_result():
print("do nothing")
def nested_rpc(dst):
return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))
def multi_layer_nested_async_rpc(dst, world_size, ttl):
# this method returns immediately without blocking the callee, but will
# generate additional requests.
if ttl > 0:
current_dst = "worker{}".format(dst)
next_dst = (dst + 1) % world_size
rpc.rpc_async(
current_dst,
multi_layer_nested_async_rpc,
args=(next_dst, world_size, ttl - 1),
)
return 0
def nested_rref(dst):
return (
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
        rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
)
def nested_remote(dst):
    rref = rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
return rref.to_here()
def rref_forward_chain(dst, world_size, rref, ttl):
if ttl > 0:
current_dst = "worker{}".format(dst)
next_dst = (dst + 1) % world_size
        ret_rref = rpc.remote(
current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
)
return [ret_rref]
else:
return rref.to_here()
def rpc_return_rref(dst):
    return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
def light_rpc():
return 0
def heavy_rpc(tensor):
for i in range(1, 100):
tensor *= i
tensor /= i + 1
return 0
def raise_func():
raise ValueError("Expected error")
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@unittest.skipIf(
sys.version_info < (3, 0),
"Pytorch distributed rpc package " "does not support python2",
)
class RpcTest(object):
@property
def world_size(self):
return 4
@property
def init_method(self):
return INIT_METHOD_TEMPLATE.format(
file_name=self.file_name, rank=self.rank, world_size=self.world_size
)
@dist_init
def test_worker_id(self):
n = self.rank + 1
peer_rank = n % self.world_size
self_worker_info = rpc.get_worker_info()
peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))
self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))
with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
unknown_worker_id = rpc.get_worker_info("WorkerUnknown")
@dist_init
def test_self_add(self):
self_worker_info = rpc.get_worker_info()
self_worker_name = "worker{}".format(self.rank)
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc.api, "_init_rpc_agent")
def test_register_rpc_backend_and_init_rpc_backend(
self, mock_init_rref_context, mock_dist_autograd_init
):
backend_name = "stub_backend"
rpc.register_backend(
backend_name, stub_init_rpc_backend_handler
)
rpc.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)
@unittest.skipIf(
TEST_CONFIG.rpc_backend != RpcBackend.PROCESS_GROUP,
"PROCESS_GROUP rpc backend specific test, skip",
)
def test_duplicate_name(self):
dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
with self.assertRaisesRegex(RuntimeError, "is not unique"):
rpc.init_model_parallel(
self_name="duplicate_name",
backend=TEST_CONFIG.rpc_backend,
self_rank=self.rank,
init_method=self.init_method,
)
rpc.join_rpc()
def test_reinit(self):
dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
rpc.init_model_parallel(
self_name="worker{}".format(self.rank),
backend=TEST_CONFIG.rpc_backend,
self_rank=self.rank,
init_method=self.init_method,
)
with self.assertRaisesRegex(RuntimeError, "is already initialized"):
rpc.init_model_parallel(
self_name="worker{}".format(self.rank),
backend=TEST_CONFIG.rpc_backend,
self_rank=self.rank,
init_method=self.init_method,
)
rpc.join_rpc()
def test_init_invalid_backend(self):
with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
rpc.init_model_parallel(
self_name="worker{}".format(self.rank),
backend="invalid",
self_rank=self.rank,
init_method=self.init_method,
)
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/25912")
def test_invalid_names(self):
dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
rpc.init_model_parallel(self_name="abc*")
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
rpc.init_model_parallel(self_name=" ")
with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
rpc.init_model_parallel(self_name="")
# If the number in the message does not match, it is likely that the
# value of MAX_NAME_LEN in RPC WorkerInfo has changed.
with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
rpc.init_model_parallel(
self_name="".join(["a" for _ in range(500)]),
backend=TEST_CONFIG.rpc_backend,
self_rank=self.rank,
init_method=self.init_method,
)
rpc.join_rpc()
@dist_init
def test_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init
def test_add_with_id(self):
n = self.rank + 1
dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init
def test_scalar_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
)
self.assertEqual(ret, (torch.ones(n, n) + n))
@dist_init
def test_async_add(self):
n = self.rank + 1
dst_rank = n % self.world_size
fut = rpc.rpc_async(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(fut.wait(), torch.ones(n, n) * 2)
@dist_init
def test_nonzero(self):
n = self.rank + 1
dst_rank = n % self.world_size
x = torch.ones(self.world_size, self.world_size)
x[self.rank][self.rank] = 0
ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
self.assertEqual(ret, x.nonzero())
@dist_init
def test_multi_rpc(self):
dst_rank = (self.rank + 1) % self.world_size
for i in range(20):
n = i + self.rank + 1
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@dist_init
def test_sync_rpc(self):
dst_rank = (self.rank + 1) % self.world_size
for i in range(20):
rpc.sync_rpc()
n = i + self.rank + 1
ret1 = rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
rpc.sync_rpc()
ret2 = rpc.rpc_sync(
"worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
)
rpc.sync_rpc()
self.assertEqual(ret1, torch.ones(n, n) * 2)
self.assertEqual(ret2, torch.ones(n, n) * 3)
@dist_init
def test_join_rpc(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, torch.ones(n, n) * 2)
rpc.join_rpc()
with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
rpc.rpc_sync(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
# it's safe to call join_rpc() multiple times
rpc.join_rpc()
@dist_init
def test_expected_src(self):
dst_rank = (self.rank + 1) % self.world_size
expected_src_rank = (self.rank - 1) % self.world_size
ret = rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
value = VALUE_FUTURE.result()
self.assertEqual(value, expected_src_rank)
@dist_init
def test_py_built_in(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
self.assertEqual(ret, min(n, n + 1, n + 2))
@dist_init
def test_py_user_defined(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
my_function,
kwargs={"a": n, "b": n + 1, "c": n + 2},
)
self.assertEqual(ret, my_function(n, n + 1, n + 2))
@dist_init
def test_py_class_constructor(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
self.assertEqual(ret.a, n)
@dist_init
def test_py_class_instance_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,)
)
self.assertEqual(ret, MyClass(2).my_instance_method(n))
@dist_init
def test_py_class_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1)
)
self.assertEqual(ret, MyClass.my_class_method(n, n + 1))
@dist_init
def test_py_class_static_method(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,)
)
self.assertEqual(ret, MyClass.my_static_method(n + 10))
@dist_init
def test_py_multi_async_call(self):
n = self.rank + 1
dst_rank = n % self.world_size
dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
@dist_init
def test_py_no_return_result(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result)
self.assertEqual(ret, no_result())
@dist_init
def test_py_tensors(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))
@dist_init
def test_py_tensors_multi_async_call(self):
futs = []
n = self.rank + 1
dst_rank = n % self.world_size
for i in range(100):
fut = rpc.rpc_async(
"worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(i, i), torch.ones(i, i)),
)
futs.append(fut)
j = 0
for fut in futs:
self.assertEqual(
fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
)
j += 1
@dist_init
def test_py_tensors_in_container(self):
n = self.rank + 1
dst_rank = n % self.world_size
a = [torch.ones(n, n), torch.ones(n, n)]
b = TensorClass(build_complex_tensors())
c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
ret = rpc.rpc_sync(
"worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
)
self.assertEqual(ret, my_complex_tensor_function(a, b, c))
@dist_init
def test_py_nested_pickle(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
run_nested_pickle,
args=(MyPickleClass(), torch.ones(2, 2)),
)
m = MyPickleClass()
m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))
@dist_init
def test_py_function_exception(self):
n = self.rank + 1
dst_rank = n % self.world_size
with self.assertRaisesRegex(Exception, "TypeError"):
ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
@dist_init
def test_py_raise_in_user_func(self):
n = self.rank + 1
dst_rank = n % self.world_size
fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func)
with self.assertRaisesRegex(Exception, "ValueError"):
fut.wait()
@dist_init
def test_nested_rpc(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
nested_rpc,
args=("worker{}".format(self.rank),),
)
self.assertEqual(ret, torch.ones(2, 2) + 1)
def _stress_test_rpc(self, f, repeat=1000, args=()):
import time
n = self.rank + 1
dst_rank = n % self.world_size
futs = []
tik = time.time()
for _ in range(repeat):
fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args)
futs.append(fut)
for fut in futs:
self.assertEqual(fut.wait(), 0)
tok = time.time()
print(
"Rank {} finished testing {} {} times in {} seconds.".format(
self.rank, f.__name__, repeat, tok - tik
)
)
@dist_init
def test_stress_light_rpc(self):
self._stress_test_rpc(light_rpc)
@dist_init
def test_stress_heavy_rpc(self):
self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))
@dist_init
def test_builtin_remote_ret(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref = rpc.remote(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
m = 10
n = self.rank + 1
dst_rank = n % self.world_size
rrefs = []
expected = []
for i in range(m):
n = n + i
rrefs.append(
rpc.remote(
"worker{}".format(dst_rank),
fn,
args=args_fn(n),
kwargs=kwargs_fn(n),
)
)
expected.append(fn(*args_fn(n), **kwargs_fn(n)))
for i in range(m):
self.assertEqual(rrefs[i].to_here(), expected[i])
@dist_init
@requires_process_group_agent
def test_multi_builtin_remote_ret(self):
def args_fn(n):
return (torch.ones(n, n), torch.ones(n, n))
self._test_multi_remote_call(torch.add, args_fn=args_fn)
@dist_init
@requires_process_group_agent
def test_py_udf_remote(self):
n = self.rank + 1
dst_rank = n % self.world_size
        rref = rpc.remote(
"worker{}".format(dst_rank),
my_function,
kwargs={"a": n, "b": n + 1, "c": n + 2},
)
self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
@dist_init
@requires_process_group_agent
def test_multi_py_udf_remote(self):
def kwargs_fn(n):
return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}
self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)
@dist_init
@requires_process_group_agent
def test_py_rref_args(self):
n = self.rank + 1
dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = rpc.remote(
"worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_py_rref_args_user_share(self):
n = self.rank + 1
owner_rank = n % self.world_size
user_rank = (n + 1) % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = rpc.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )
        rref_c = rpc.remote(
"worker{}".format(user_rank), my_rref_function, args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_py_rpc_rref_args(self):
n = self.rank + 1
dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )
        c = rpc.rpc_sync(
"worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
)
self.assertEqual(c, torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_nested_remote(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
        rref = rpc.remote(
"worker{}".format(dst_rank1),
nested_remote,
args=("worker{}".format(dst_rank2),),
)
self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
@dist_init
@requires_process_group_agent
def test_nested_rref(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
        rref_of_rrefs = rpc.remote(
"worker{}".format(dst_rank1),
nested_rref,
args=("worker{}".format(dst_rank2),),
)
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
@dist_init
@requires_process_group_agent
def test_nested_rref_stress(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
all_rrefs = []
for _ in range(20):
all_rrefs.append(
                rpc.remote(
"worker{}".format(dst_rank1),
nested_rref,
args=("worker{}".format(dst_rank2),),
)
)
for i in range(20):
rref_of_rrefs = all_rrefs[i]
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
@dist_init
@requires_process_group_agent
def test_multi_layer_nested_async_rpc(self):
        # This test will exit right away, but there will be a chain of async
        # RPCs. The termination algorithm should detect those messages
        # properly. Otherwise, some peers could exit early, leaving others to
        # fail with timeout or connection-closed errors.
ttl = 20
n = self.rank + 1
dst_rank = n % self.world_size
multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
@dist_init
@requires_process_group_agent
def test_remote_with_exception(self):
n = self.rank + 1
dst_rank = n % self.world_size
        rref = rpc.remote("worker{}".format(dst_rank), raise_func)
with self.assertRaisesRegex(Exception, "ValueError"):
rref.to_here()
@dist_init
@requires_process_group_agent
def test_rpc_return_rref(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
        rref = rpc.rpc_sync(
"worker{}".format(dst_rank1),
rpc_return_rref,
args=("worker{}".format(dst_rank2),),
)
self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
@dist_init
@requires_process_group_agent
def test_rref_forward_chain(self):
ttl = 8
n = self.rank + 1
dst_rank = n % self.world_size
        rref = rpc.remote(
"worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
)
ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
for i in range(ttl):
self.assertEqual(len(ret_rref), 1)
ret_rref = ret_rref[0].to_here()
ret = ret_rref
self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
@dist_init
@requires_process_group_agent
def test_remote_same_worker(self):
n = self.rank + 1
dst_rank = n % self.world_size
        rref_a = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = rpc.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = rpc.remote(
"worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)