Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022

This change implements the "FAST" mode distributed autograd backward pass as described in https://github.com/pytorch/pytorch/issues/23110.

At a high level the backward pass works as follows:

1. We start by computing dependencies on the node that calls `torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in the backward call and from all the 'send' functions present in the current autograd context. The "FAST" mode assumes all 'send' functions are part of the autograd computation.
3. Once the dependency computation is done, the distributed autograd engine calls the local autograd engine to execute the autograd graph. Note that the autograd graph on a single node is not necessarily connected because of inter-node communication. As a result, we have special handling to ensure the local autograd engine executes the entire graph starting from the provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async RPC to send the gradients over to the appropriate node and stores a future in the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and enqueued on the local autograd engine. If this is the first time the node is hearing about this autograd context id on the backward pass, the node computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers all leaf nodes and ensures those are passed as 'outputs' to the local autograd engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in `DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine to complete and also waits for all the 'Futures' (stored in step 4) for the respective RPCs to finish.

We have made the following changes to the local autograd engine for this purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can use them.
2. Expose an `execute_with_graph_task` API, which allows the distributed engine to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this, a few general improvements:

1. Added a `PropagateGradients` RPC call for the 'recv' function to pass gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait() receives a message of type EXCEPTION, we throw an appropriate exception instead of just returning the message. This is in line with what most Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map from Tensor to respective gradient for the provided context_id on the local node.

ghstack-source-id: 91794926

Test Plan: unit tests.

Differential Revision: D17652615

fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
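To make the flow above concrete, here is a minimal sketch of what driving this backward pass from Python might look like. It is illustrative only: the `dist_autograd.context()` context manager and the exact `backward`/`get_gradients` signatures are assumptions based on the description above, not a stable public API.

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    # Sketch only; API names assumed from the summary above.
    # Each forward-pass RPC inside the context records matching
    # 'send'/'recv' autograd functions on the two workers involved.
    with dist_autograd.context() as context_id:
        t = torch.ones(2, 2, requires_grad=True)
        # Runs torch.add(t, t) on "worker1"; the result carries autograd
        # history that crosses the node boundary.
        out = rpc.rpc_sync("worker1", torch.add, args=(t, t))
        loss = out.sum()
        # Kick off the FAST-mode backward pass described above.
        dist_autograd.backward([loss])
        # Gradients accumulate per-context (step 7), not in t.grad.
        grads = dist_autograd.get_gradients(context_id)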
830 lines
26 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals

import concurrent.futures
import sys
import unittest
from collections import namedtuple
from unittest import mock

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
from common_utils import load_tests
from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init
from torch.distributed import ProcessGroupAgent
from torch.distributed.rpc import RpcBackend
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler


def requires_process_group_agent(func):
    from torch.distributed.rpc.api import _agent

    # Apply the skip condition to the decorated test function; without the
    # trailing (func) call this would return the bare skip decorator instead
    # of the wrapped test.
    return unittest.skipUnless(
        isinstance(_agent, ProcessGroupAgent),
        "Only ProcessGroupAgent supports global termination detection",
    )(func)
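
# Usage note: the decorator above is applied together with @dist_init on
# individual test methods further down in this file, e.g.
#
#     @dist_init
#     @requires_process_group_agent
#     def test_py_udf_remote(self):
#         ...
#
# so those tests are skipped unless the active RPC agent is a
# ProcessGroupAgent.

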
VALUE_FUTURE = concurrent.futures.Future()


def stub_init_rpc_backend_handler(self_rank, self_name, init_method):
    return mock.Mock()  # RpcAgent.


def set_value(value):
    VALUE_FUTURE.set_result(value)


# The classes and functions below are used to exercise Python user-defined
# functions, classes, and methods over RPC.
TensorClass = namedtuple("TensorClass", ["tensors"])


class MyPickleClass:
    def __init__(self):
        self.t = None

    def __getstate__(self):
        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
            PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None)
        )
        return (pickled_python_udf, tensors)

    def __setstate__(self, obj):
        python_udf = _internal_rpc_pickler.deserialize(obj[0], obj[1])
        result = python_udf.func(python_udf.args[0], python_udf.args[1])
        self.t = result

    def set(self, val):
        self.t = val


class MyClass:
    def __init__(self, a):
        self.a = a

    def my_instance_method(self, b):
        return self.a + b

    @classmethod
    def my_class_method(cls, d, e):
        return d + e

    @staticmethod
    def my_static_method(f):
        return f > 10


def run_nested_pickle(pickle_cls_instance, tensor):
    return pickle_cls_instance.t + tensor


def build_complex_tensors():
    a = torch.ones(3, 3)
    b = [a, a]
    c = [b, b]
    d = [a, b]
    e = {a: d}
    return [a, b, c, d, e]


def my_function(a, b, c):
    return a + b + c


def my_tensor_function(a, b):
    return a + b


def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
    res = list_input[0]
    for t in list_input:
        res += t
    for k, v in dict_input.items():
        res += v
    complex_tensors = tensor_class_input.tensors
    return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
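
# Note: `res` starts from list_input[0] and then adds every element of
# list_input (including the first), so the first tensor is counted twice.
# The tests compare against a local call of this same function, so the
# asymmetry is self-consistent.

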
def my_rref_function(rref_a, rref_b):
    return rref_a.to_here() + rref_b.to_here()


def no_result():
    print("do nothing")


def nested_rpc(dst):
    return rpc.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1))


def multi_layer_nested_async_rpc(dst, world_size, ttl):
    # This function returns immediately without blocking on the nested call,
    # but it generates additional async RPC requests down the chain.
    if ttl > 0:
        current_dst = "worker{}".format(dst)
        next_dst = (dst + 1) % world_size
        rpc.rpc_async(
            current_dst,
            multi_layer_nested_async_rpc,
            args=(next_dst, world_size, ttl - 1),
        )
        return 0


def nested_rref(dst):
    return (
        dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
        dist.remote(dst, torch.add, args=(torch.ones(2, 2), 2)),
    )


def nested_remote(dst):
    rref = dist.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
    return rref.to_here()


def rref_forward_chain(dst, world_size, rref, ttl):
    if ttl > 0:
        current_dst = "worker{}".format(dst)
        next_dst = (dst + 1) % world_size
        ret_rref = dist.remote(
            current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1)
        )
        return [ret_rref]
    else:
        return rref.to_here()
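
# How the chain above unrolls, sketched for ttl=2 (each hop wraps the previous
# result in a one-element list holding an RRef, exactly as the code does):
#
#   rref_forward_chain(d0, ws, rref, ttl=2) -> [RRef_1]         # hop on worker d0
#   RRef_1.to_here()                        -> [RRef_2]         # hop on worker d1
#   RRef_2.to_here()                        -> rref.to_here()   # the original tensor
#
# test_rref_forward_chain below peels off one layer per iteration.

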
def rpc_return_rref(dst):
    return dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1))


def light_rpc():
    return 0


def heavy_rpc(tensor):
    for i in range(1, 100):
        tensor *= i
        tensor /= i + 1
    return 0
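
# What heavy_rpc actually computes: each iteration does tensor *= i then
# tensor /= (i + 1), so after i = 1..99 the tensor is scaled by
# prod(i / (i + 1)) = 1/100 overall. The point is the repeated in-place work
# (and the 100x100 payload in test_stress_heavy_rpc), not the result, which
# is why it just returns 0.

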
def raise_func():
    raise ValueError("Expected error")


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake8 warnings.
load_tests = load_tests


@unittest.skipIf(
    sys.version_info < (3, 0),
    "PyTorch distributed rpc package does not support Python 2",
)
class RpcTest(object):
    @property
    def world_size(self):
        return 4

    @property
    def init_method(self):
        return INIT_METHOD_TEMPLATE.format(
            file_name=self.file_name, rank=self.rank, world_size=self.world_size
        )

    @dist_init
    def test_worker_id(self):
        n = self.rank + 1
        peer_rank = n % self.world_size
        self_worker_info = rpc.get_worker_info()
        peer_worker_info = rpc.get_worker_info("worker{}".format(peer_rank))

        self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
        self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))

        with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
            rpc.get_worker_info("WorkerUnknown")

    @dist_init
    def test_self_add(self):
        self_worker_info = rpc.get_worker_info()
        self_worker_name = "worker{}".format(self.rank)

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1))

        with self.assertRaisesRegex(
            RuntimeError, "does not support making RPC calls to self"
        ):
            rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))

    @mock.patch.object(torch.distributed.autograd, "_init")
    @mock.patch.object(torch.distributed.rpc.api, "_init_rpc_agent")
    def test_register_rpc_backend_and_init_rpc_backend(
        self, mock_init_rref_context, mock_dist_autograd_init
    ):
        backend_name = "stub_backend"
        rpc.register_backend(backend_name, stub_init_rpc_backend_handler)
        rpc.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1)

    @unittest.skipIf(
        TEST_CONFIG.rpc_backend != RpcBackend.PROCESS_GROUP,
        "PROCESS_GROUP rpc backend specific test, skip",
    )
    def test_duplicate_name(self):
        dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
        with self.assertRaisesRegex(RuntimeError, "is not unique"):
            rpc.init_model_parallel(
                self_name="duplicate_name",
                backend=TEST_CONFIG.rpc_backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        rpc.join_rpc()

    def test_reinit(self):
        dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)
        rpc.init_model_parallel(
            self_name="worker{}".format(self.rank),
            backend=TEST_CONFIG.rpc_backend,
            self_rank=self.rank,
            init_method=self.init_method,
        )
        with self.assertRaisesRegex(RuntimeError, "is already initialized"):
            rpc.init_model_parallel(
                self_name="worker{}".format(self.rank),
                backend=TEST_CONFIG.rpc_backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        rpc.join_rpc()

    def test_init_invalid_backend(self):
        with self.assertRaisesRegex(RuntimeError, "Unrecognized RPC backend"):
            rpc.init_model_parallel(
                self_name="worker{}".format(self.rank),
                backend="invalid",
                self_rank=self.rank,
                init_method=self.init_method,
            )

    @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/25912")
    def test_invalid_names(self):
        dist.init_process_group(backend=dist.Backend.GLOO, init_method=self.init_method)

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name="abc*")

        with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
            rpc.init_model_parallel(self_name=" ")

        with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
            rpc.init_model_parallel(self_name="")

        # If the number in the message does not match, it is likely that the
        # value of MAX_NAME_LEN in RPC WorkerInfo has changed.
        with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
            rpc.init_model_parallel(
                self_name="a" * 500,
                backend=TEST_CONFIG.rpc_backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
        rpc.join_rpc()

    @dist_init
    def test_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_add_with_id(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        worker_info = rpc.get_worker_info("worker{}".format(dst_rank))

        ret = rpc.rpc_sync(
            worker_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_scalar_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), n)
        )
        self.assertEqual(ret, (torch.ones(n, n) + n))

    @dist_init
    def test_async_add(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(fut.wait(), torch.ones(n, n) * 2)

    @dist_init
    def test_nonzero(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        x = torch.ones(self.world_size, self.world_size)
        x[self.rank][self.rank] = 0
        ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,))
        self.assertEqual(ret, x.nonzero())

    @dist_init
    def test_multi_rpc(self):
        dst_rank = (self.rank + 1) % self.world_size
        for i in range(20):
            n = i + self.rank + 1
            ret = rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )
            self.assertEqual(ret, torch.ones(n, n) * 2)

    @dist_init
    def test_sync_rpc(self):
        dst_rank = (self.rank + 1) % self.world_size
        for i in range(20):
            rpc.sync_rpc()
            n = i + self.rank + 1
            ret1 = rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )
            rpc.sync_rpc()
            ret2 = rpc.rpc_sync(
                "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
            )
            rpc.sync_rpc()
            self.assertEqual(ret1, torch.ones(n, n) * 2)
            self.assertEqual(ret2, torch.ones(n, n) * 3)

    @dist_init
    def test_join_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, torch.ones(n, n) * 2)
        rpc.join_rpc()

        with self.assertRaisesRegex(RuntimeError, "^RPC has not been initialized"):
            rpc.rpc_sync(
                "worker{}".format(dst_rank),
                torch.add,
                args=(torch.ones(n, n), torch.ones(n, n)),
            )

        # it's safe to call join_rpc() multiple times
        rpc.join_rpc()

    @dist_init
    def test_expected_src(self):
        dst_rank = (self.rank + 1) % self.world_size
        expected_src_rank = (self.rank - 1) % self.world_size
        rpc.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,))
        value = VALUE_FUTURE.result()
        self.assertEqual(value, expected_src_rank)

    @dist_init
    def test_py_built_in(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2))
        self.assertEqual(ret, min(n, n + 1, n + 2))

    @dist_init
    def test_py_user_defined(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(ret, my_function(n, n + 1, n + 2))

    @dist_init
    def test_py_class_constructor(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,))
        self.assertEqual(ret.a, n)

    @dist_init
    def test_py_class_instance_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass(2).my_instance_method, args=(n,)
        )
        self.assertEqual(ret, MyClass(2).my_instance_method(n))

    @dist_init
    def test_py_class_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass.my_class_method, args=(n, n + 1)
        )
        self.assertEqual(ret, MyClass.my_class_method(n, n + 1))

    @dist_init
    def test_py_class_static_method(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), MyClass.my_static_method, args=(n + 10,)
        )
        self.assertEqual(ret, MyClass.my_static_method(n + 10))

    @dist_init
    def test_py_multi_async_call(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        dst_worker_info = rpc.get_worker_info("worker{}".format(dst_rank))
        fut1 = rpc.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,))
        fut2 = rpc.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2))
        self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
        self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))

    @dist_init
    def test_py_no_return_result(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync("worker{}".format(dst_rank), no_result)
        self.assertEqual(ret, no_result())

    @dist_init
    def test_py_tensors(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            my_tensor_function,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n)))

    @dist_init
    def test_py_tensors_multi_async_call(self):
        futs = []
        n = self.rank + 1
        dst_rank = n % self.world_size
        for i in range(100):
            fut = rpc.rpc_async(
                "worker{}".format(dst_rank),
                my_tensor_function,
                args=(torch.ones(i, i), torch.ones(i, i)),
            )
            futs.append(fut)

        for j, fut in enumerate(futs):
            self.assertEqual(
                fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j))
            )

    @dist_init
    def test_py_tensors_in_container(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        a = [torch.ones(n, n), torch.ones(n, n)]
        b = TensorClass(build_complex_tensors())
        c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c)
        )
        self.assertEqual(ret, my_complex_tensor_function(a, b, c))

    @dist_init
    def test_py_nested_pickle(self):
        n = self.rank + 1
        dst_rank = n % self.world_size

        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            run_nested_pickle,
            args=(MyPickleClass(), torch.ones(2, 2)),
        )

        m = MyPickleClass()
        m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
        self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2)))

    @dist_init
    def test_py_function_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        with self.assertRaisesRegex(Exception, "TypeError"):
            rpc.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))

    @dist_init
    def test_py_raise_in_user_func(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        fut = rpc.rpc_async("worker{}".format(dst_rank), raise_func)
        with self.assertRaisesRegex(Exception, "ValueError"):
            fut.wait()

    @dist_init
    def test_nested_rpc(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        ret = rpc.rpc_sync(
            "worker{}".format(dst_rank),
            nested_rpc,
            args=("worker{}".format(self.rank),),
        )
        self.assertEqual(ret, torch.ones(2, 2) + 1)

    def _stress_test_rpc(self, f, repeat=1000, args=()):
        import time

        n = self.rank + 1
        dst_rank = n % self.world_size
        futs = []
        tik = time.time()
        for _ in range(repeat):
            fut = rpc.rpc_async("worker{}".format(dst_rank), f, args=args)
            futs.append(fut)

        for fut in futs:
            self.assertEqual(fut.wait(), 0)
        tok = time.time()
        print(
            "Rank {} finished testing {} {} times in {} seconds.".format(
                self.rank, f.__name__, repeat, tok - tik
            )
        )

    @dist_init
    def test_stress_light_rpc(self):
        self._stress_test_rpc(light_rpc)

    @dist_init
    def test_stress_heavy_rpc(self):
        self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),))

    @dist_init
    def test_builtin_remote_ret(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = rpc.remote(
            "worker{}".format(dst_rank),
            torch.add,
            args=(torch.ones(n, n), torch.ones(n, n)),
        )
        self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)

    def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}):
        m = 10
        n = self.rank + 1
        dst_rank = n % self.world_size
        rrefs = []
        expected = []
        for i in range(m):
            n = n + i
            rrefs.append(
                rpc.remote(
                    "worker{}".format(dst_rank),
                    fn,
                    args=args_fn(n),
                    kwargs=kwargs_fn(n),
                )
            )
            expected.append(fn(*args_fn(n), **kwargs_fn(n)))

        for i in range(m):
            self.assertEqual(rrefs[i].to_here(), expected[i])

    @dist_init
    @requires_process_group_agent
    def test_multi_builtin_remote_ret(self):
        def args_fn(n):
            return (torch.ones(n, n), torch.ones(n, n))

        self._test_multi_remote_call(torch.add, args_fn=args_fn)

    @dist_init
    @requires_process_group_agent
    def test_py_udf_remote(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = dist.remote(
            "worker{}".format(dst_rank),
            my_function,
            kwargs={"a": n, "b": n + 1, "c": n + 2},
        )
        self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))

    @dist_init
    @requires_process_group_agent
    def test_multi_py_udf_remote(self):
        def kwargs_fn(n):
            return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)}

        self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)

    @dist_init
    @requires_process_group_agent
    def test_py_rref_args(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = dist.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = dist.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = dist.remote(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

    @dist_init
    @requires_process_group_agent
    def test_py_rref_args_user_share(self):
        n = self.rank + 1
        owner_rank = n % self.world_size
        user_rank = (n + 1) % self.world_size
        rref_a = dist.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = dist.remote(
            "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )
        rref_c = dist.remote(
            "worker{}".format(user_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

    @dist_init
    @requires_process_group_agent
    def test_py_rpc_rref_args(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = dist.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0)
        )
        rref_b = dist.remote(
            "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0)
        )

        c = dist.rpc_sync(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )

        self.assertEqual(c, torch.ones(n, n) + 4)

    @dist_init
    @requires_process_group_agent
    def test_nested_remote(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref = dist.remote(
            "worker{}".format(dst_rank1),
            nested_remote,
            args=("worker{}".format(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)

    @dist_init
    @requires_process_group_agent
    def test_nested_rref(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref_of_rrefs = dist.remote(
            "worker{}".format(dst_rank1),
            nested_rref,
            args=("worker{}".format(dst_rank2),),
        )
        rrefs = rref_of_rrefs.to_here()
        self.assertEqual(len(rrefs), 2)
        self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
        self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)

    @dist_init
    @requires_process_group_agent
    def test_nested_rref_stress(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        all_rrefs = []
        for _ in range(20):
            all_rrefs.append(
                dist.remote(
                    "worker{}".format(dst_rank1),
                    nested_rref,
                    args=("worker{}".format(dst_rank2),),
                )
            )

        for i in range(20):
            rref_of_rrefs = all_rrefs[i]
            rrefs = rref_of_rrefs.to_here()
            self.assertEqual(len(rrefs), 2)
            self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
            self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)

    @dist_init
    @requires_process_group_agent
    def test_multi_layer_nested_async_rpc(self):
        # This test returns right away, but it leaves behind a chain of async
        # RPCs. The termination detection algorithm must track those messages
        # properly; otherwise some peer could exit early, leaving the others
        # to hit timeout or connection-closed errors.
        ttl = 20
        n = self.rank + 1
        dst_rank = n % self.world_size

        multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)

    @dist_init
    @requires_process_group_agent
    def test_remote_with_exception(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref = dist.remote("worker{}".format(dst_rank), raise_func)
        with self.assertRaisesRegex(Exception, "ValueError"):
            rref.to_here()

    @dist_init
    @requires_process_group_agent
    def test_rpc_return_rref(self):
        n = self.rank + 1
        dst_rank1 = n % self.world_size
        dst_rank2 = (n + 1) % self.world_size
        rref = dist.rpc_sync(
            "worker{}".format(dst_rank1),
            rpc_return_rref,
            args=("worker{}".format(dst_rank2),),
        )
        self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)

    @dist_init
    @requires_process_group_agent
    def test_rref_forward_chain(self):
        ttl = 8
        n = self.rank + 1
        dst_rank = n % self.world_size

        rref = dist.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )

        ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)

        for _ in range(ttl):
            self.assertEqual(len(ret_rref), 1)
            ret_rref = ret_rref[0].to_here()

        ret = ret_rref
        self.assertEqual(ret, torch.add(torch.ones(n, n), 1))

    @dist_init
    @requires_process_group_agent
    def test_remote_same_worker(self):
        n = self.rank + 1
        dst_rank = n % self.world_size
        rref_a = dist.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2)
        )
        rref_b = dist.remote(
            "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1)
        )
        rref_c = dist.remote(
            "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b)
        )
        self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)