Add Python RRef as args and return value (#25499)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499

See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit add support for using RRef as Python UDF arguments
and return value. RRefs can now be shared from owner to user, from user to
owner, or from user to user.

Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5.  Update `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) are tracked in `rref_context.h` and `rref_context.cpp`.

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
This commit is contained in:
Shen Li
2019-10-03 17:45:30 -07:00
committed by Facebook Github Bot
parent 8fe5dcf699
commit 2486b0ba82
35 changed files with 2105 additions and 464 deletions

View File

@ -488,14 +488,16 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_rref_proto.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp
${TORCH_SRC_DIR}/csrc/jit/export.cpp
${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp

View File

@ -10,11 +10,20 @@ import torch
import torch.distributed as dist
import torch.distributed.rpc_backend_registry as rpc_backend_registry
from collections import namedtuple
from torch.distributed.internal_rpc_utils import _internal_rpc_pickler, PythonUDF
from common_utils import load_tests
from torch.distributed.rpc_api import RpcBackend
from dist_utils import dist_init
from torch.distributed import ProcessGroupAgent
from torch.distributed.internal_rpc_utils import _internal_rpc_pickler, PythonUDF
from torch.distributed.rpc_api import RpcBackend
def requires_process_group_agent(func):
from torch.distributed.rpc_api import _agent
return unittest.skipUnless(
isinstance(_agent, ProcessGroupAgent),
"Only ProcessGroupAgent supports global termination detection",
)
BACKEND = getenv("RPC_BACKEND", RpcBackend.PROCESS_GROUP)
RPC_INIT_URL = getenv("RPC_INIT_URL", "")
@ -97,6 +106,10 @@ def my_complex_tensor_function(list_input, tensor_class_input, dict_input):
return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2])
def my_rref_function(rref_a, rref_b):
return rref_a.to_here() + rref_b.to_here()
def no_result():
print("do nothing")
@ -110,18 +123,52 @@ def multi_layer_nested_async_rpc(dst, world_size, ttl):
if ttl > 0:
current_dst = "worker{}".format(dst)
next_dst = (dst + 1) % world_size
dist.rpc(
dist.rpc_async(
current_dst,
multi_layer_nested_async_rpc,
args=(
next_dst,
world_size,
ttl - 1
),
async_call=True
)
)
return 0
def nested_rref(dst):
return (
dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1)),
dist.remote(dst, torch.add, args=(torch.ones(2, 2), 2))
)
def nested_remote(dst):
rref = dist.remote(dst, torch.add, args=(torch.ones(2, 2), 3))
return rref.to_here()
def rref_forward_chain(dst, world_size, rref, ttl):
if ttl > 0:
current_dst = "worker{}".format(dst)
next_dst = (dst + 1) % world_size
ret_rref = dist.remote(
current_dst,
rref_forward_chain,
args=(
next_dst,
world_size,
rref,
ttl - 1
)
)
return [ret_rref]
else:
return rref.to_here()
def rpc_return_rref(dst):
return dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
def light_rpc():
return 0
@ -155,32 +202,40 @@ class RpcTest(object):
def test_worker_id(self):
n = self.rank + 1
peer_rank = n % self.world_size
self_worker_id = dist.get_worker_id()
peer_worker_id = dist.get_worker_id("worker{}".format(peer_rank))
self_worker_info = dist.get_worker_info()
peer_worker_info = dist.get_worker_info("worker{}".format(peer_rank))
self.assertEqual(self_worker_id.name, "worker{}".format(self.rank))
self.assertEqual(peer_worker_id.name, "worker{}".format(peer_rank))
self.assertEqual(self_worker_info.name, "worker{}".format(self.rank))
self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank))
with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"):
unknown_worker_id = dist.get_worker_id("WorkerUnknown")
unknown_worker_id = dist.get_worker_info("WorkerUnknown")
@dist_init
def test_self_add(self):
self_worker_id = dist.get_worker_id()
self_worker_info = dist.get_worker_info()
self_worker_name = "worker{}".format(self.rank)
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
dist.rpc_sync(self_worker_id, torch.add, args=(torch.ones(2, 2), 1))
dist.rpc_sync(
self_worker_info,
torch.add,
args=(torch.ones(2, 2), 1)
)
with self.assertRaisesRegex(
RuntimeError, "does not support making RPC calls to self"
):
dist.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
dist.rpc_sync(
self_worker_name,
torch.add,
args=(torch.ones(2, 2), 1)
)
@mock.patch.object(torch.distributed.autograd, "_init")
@mock.patch.object(torch.distributed.rpc_api, "init_rref_context")
@mock.patch.object(torch.distributed.rpc_api, "_init_rref_context")
def test_register_rpc_backend_and_init_rpc_backend(
self, mock_init_rref_context, mock_dist_autograd_init
):
@ -254,7 +309,7 @@ class RpcTest(object):
dist.init_model_parallel(self_name="")
# If the number in the message does not match, it is likely that the
# value of MAX_NAME_LEN in RPC WorkerId has changed.
# value of MAX_NAME_LEN in RPC WorkerInfo has changed.
with self.assertRaisesRegex(RuntimeError, "shorter than 128"):
dist.init_model_parallel(
self_name="".join(["a" for _ in range(500)]),
@ -279,10 +334,10 @@ class RpcTest(object):
def test_add_with_id(self):
n = self.rank + 1
dst_rank = n % self.world_size
workder_id = dist.get_worker_id("worker{}".format(dst_rank))
workder_info = dist.get_worker_info("worker{}".format(dst_rank))
ret = dist.rpc_sync(
workder_id, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n))
)
self.assertEqual(ret, torch.ones(n, n) * 2)
@ -424,9 +479,17 @@ class RpcTest(object):
def test_py_multi_async_call(self):
n = self.rank + 1
dst_rank = n % self.world_size
dst_worker_id = dist.get_worker_id("worker{}".format(dst_rank))
fut1 = dist.rpc_async(dst_worker_id, MyClass.my_static_method, args=(n + 10,))
fut2 = dist.rpc_async(dst_worker_id, min, args=(n, n + 1, n + 2))
dst_worker_info = dist.get_worker_info("worker{}".format(dst_rank))
fut1 = dist.rpc_async(
dst_worker_info,
MyClass.my_static_method,
args=(n + 10,)
)
fut2 = dist.rpc_async(
dst_worker_info,
min,
args=(n, n + 1, n + 2)
)
self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10))
self.assertEqual(fut2.wait(), min(n, n + 1, n + 2))
@ -441,9 +504,11 @@ class RpcTest(object):
def test_py_tensors(self):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc("worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(n, n), torch.ones(n, n)))
ret = dist.rpc_sync(
"worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(n, n), torch.ones(n, n))
)
self.assertEqual(ret,
my_tensor_function(torch.ones(n, n),
torch.ones(n, n)))
@ -454,10 +519,11 @@ class RpcTest(object):
n = self.rank + 1
dst_rank = n % self.world_size
for i in range(100):
fut = dist.rpc("worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(i, i), torch.ones(i, i)),
async_call=True)
fut = dist.rpc_async(
"worker{}".format(dst_rank),
my_tensor_function,
args=(torch.ones(i, i), torch.ones(i, i))
)
futs.append(fut)
j = 0
@ -474,9 +540,11 @@ class RpcTest(object):
a = [torch.ones(n, n), torch.ones(n, n)]
b = TensorClass(build_complex_tensors())
c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)}
ret = dist.rpc("worker{}".format(dst_rank),
my_complex_tensor_function,
args=(a, b, c))
ret = dist.rpc_sync(
"worker{}".format(dst_rank),
my_complex_tensor_function,
args=(a, b, c)
)
self.assertEqual(ret, my_complex_tensor_function(a, b, c))
@dist_init
@ -484,9 +552,11 @@ class RpcTest(object):
n = self.rank + 1
dst_rank = n % self.world_size
ret = dist.rpc("worker{}".format(dst_rank),
run_nested_pickle,
args=(MyPickleClass(), torch.ones(2, 2)))
ret = dist.rpc_sync(
"worker{}".format(dst_rank),
run_nested_pickle,
args=(MyPickleClass(), torch.ones(2, 2))
)
m = MyPickleClass()
m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2)))
@ -497,7 +567,11 @@ class RpcTest(object):
n = self.rank + 1
dst_rank = n % self.world_size
with self.assertRaisesRegex(Exception, "TypeError"):
ret = dist.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,))
ret = dist.rpc_sync(
"worker{}".format(dst_rank),
no_result,
args=(10,)
)
@dist_init
def test_py_raise_in_user_func(self):
@ -557,8 +631,9 @@ class RpcTest(object):
)
self.assertEqual(rref.to_here(), torch.ones(n, n) * 2)
@dist_init
def test_multi_builtin_remote_ret(self):
def _test_multi_remote_call(
self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}
):
m = 10
n = self.rank + 1
dst_rank = n % self.world_size
@ -566,19 +641,171 @@ class RpcTest(object):
expected = []
for i in range(m):
n = n + i
rrefs.append(
dist.remote(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), torch.ones(n, n)),
)
)
expected.append(torch.ones(n, n) * 2)
rrefs.append(dist.remote(
'worker{}'.format(dst_rank),
fn,
args=args_fn(n),
kwargs=kwargs_fn(n)
))
expected.append(fn(*args_fn(n), **kwargs_fn(n)))
for i in range(m):
self.assertEqual(rrefs[i].to_here(), expected[i])
@dist_init
@requires_process_group_agent
def test_multi_builtin_remote_ret(self):
def args_fn(n):
return (torch.ones(n, n), torch.ones(n, n))
self._test_multi_remote_call(torch.add, args_fn=args_fn)
@dist_init
@requires_process_group_agent
def test_py_udf_remote(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref = dist.remote(
"worker{}".format(dst_rank),
my_function,
kwargs={"a": n, "b": n + 1, "c": n + 2},
)
self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2))
@dist_init
@requires_process_group_agent
def test_multi_py_udf_remote(self):
def kwargs_fn(n):
return {
"a" : torch.ones(n, n),
"b" : torch.ones(n, n),
"c" : torch.ones(n, n),
}
self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn)
@dist_init
@requires_process_group_agent
def test_py_rref_args(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_a = dist.remote(
'worker{}'.format(dst_rank),
torch.add,
args=(torch.ones(n, n), 2)
)
rref_b = dist.remote(
'worker{}'.format(dst_rank),
torch.add,
args=(torch.ones(n, n), 1)
)
rref_c = dist.remote(
'worker{}'.format(dst_rank),
my_rref_function,
args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_py_rref_args_user_share(self):
n = self.rank + 1
owner_rank = n % self.world_size
user_rank = (n + 1) % self.world_size
rref_a = dist.remote(
'worker{}'.format(owner_rank),
my_function,
args=(torch.ones(n, n), 2, 0)
)
rref_b = dist.remote(
'worker{}'.format(owner_rank),
my_function,
args=(torch.ones(n, n), 1, 0)
)
rref_c = dist.remote(
'worker{}'.format(user_rank),
my_rref_function,
args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_py_rpc_rref_args(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_a = dist.remote(
"worker{}".format(dst_rank),
my_function,
args=(torch.ones(n, n), 2, 0)
)
rref_b = dist.remote(
"worker{}".format(dst_rank),
my_function,
args=(torch.ones(n, n), 1, 0)
)
c = dist.rpc_sync(
"worker{}".format(dst_rank),
my_rref_function,
args=(rref_a, rref_b)
)
self.assertEqual(c, torch.ones(n, n) + 4)
@dist_init
@requires_process_group_agent
def test_nested_remote(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
rref = dist.remote(
"worker{}".format(dst_rank1),
nested_remote,
args=("worker{}".format(dst_rank2),)
)
self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3)
@dist_init
@requires_process_group_agent
def test_nested_rref(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
rref_of_rrefs = dist.remote(
"worker{}".format(dst_rank1),
nested_rref,
args=("worker{}".format(dst_rank2),)
)
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
@dist_init
@requires_process_group_agent
def test_nested_rref_stress(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
all_rrefs = []
for _ in range(20):
all_rrefs.append(
dist.remote(
"worker{}".format(dst_rank1),
nested_rref,
args=("worker{}".format(dst_rank2),)
)
)
for i in range(20):
rref_of_rrefs = all_rrefs[i]
rrefs = rref_of_rrefs.to_here()
self.assertEqual(len(rrefs), 2)
self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1)
self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2)
@dist_init
@requires_process_group_agent
def test_multi_layer_nested_async_rpc(self):
# This test will exit right away, but there will be a chain of async
# RPCs. The termination algorithm should detect those messages properly.
@ -589,3 +816,72 @@ class RpcTest(object):
dst_rank = n % self.world_size
multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl)
@dist_init
@requires_process_group_agent
def test_remote_with_exception(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref = dist.remote(
"worker{}".format(dst_rank),
raise_func
)
with self.assertRaisesRegex(Exception, "ValueError"):
rref.to_here()
@dist_init
@requires_process_group_agent
def test_rpc_return_rref(self):
n = self.rank + 1
dst_rank1 = n % self.world_size
dst_rank2 = (n + 1) % self.world_size
rref = dist.rpc_sync(
"worker{}".format(dst_rank1),
rpc_return_rref,
args=("worker{}".format(dst_rank2),)
)
self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1)
@dist_init
@requires_process_group_agent
def test_rref_forward_chain(self):
ttl = 8
n = self.rank + 1
dst_rank = n % self.world_size
rref = dist.remote(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), 1)
)
ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl)
for i in range(ttl):
self.assertEqual(len(ret_rref), 1)
ret_rref = ret_rref[0].to_here()
ret = ret_rref
self.assertEqual(ret, torch.add(torch.ones(n, n), 1))
@dist_init
@requires_process_group_agent
def test_remote_same_worker(self):
n = self.rank + 1
dst_rank = n % self.world_size
rref_a = dist.remote(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), 2)
)
rref_b = dist.remote(
"worker{}".format(dst_rank),
torch.add,
args=(torch.ones(n, n), 1)
)
rref_c = dist.remote(
"worker{}".format(dst_rank),
my_rref_function,
args=(rref_a, rref_b)
)
self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4)

View File

@ -56,14 +56,16 @@ libtorch_sources = [
"torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp",
"torch/csrc/distributed/rpc/future_message.cpp",
"torch/csrc/distributed/rpc/message.cpp",
"torch/csrc/distributed/rpc/python_remote_call.cpp",
"torch/csrc/distributed/rpc/python_udf_call.cpp",
"torch/csrc/distributed/rpc/python_udf_resp.cpp",
"torch/csrc/distributed/rpc/request_callback.cpp",
"torch/csrc/distributed/rpc/rpc_with_autograd.cpp",
"torch/csrc/distributed/rpc/rref_proto.cpp",
"torch/csrc/distributed/rpc/script_call.cpp",
"torch/csrc/distributed/rpc/script_remote_call.cpp",
"torch/csrc/distributed/rpc/script_rref_proto.cpp",
"torch/csrc/distributed/rpc/script_resp.cpp",
"torch/csrc/distributed/rpc/types.cpp",
"torch/csrc/distributed/rpc/utils.cpp",
"torch/csrc/Exceptions.cpp",
"torch/csrc/jit/autodiff.cpp",
@ -269,13 +271,13 @@ def add_torch_libs():
"torch/csrc/distributed/c10d/reducer.cpp",
"torch/csrc/distributed/rpc/init.cpp",
"torch/csrc/distributed/rpc/process_group_agent.cpp",
"torch/csrc/distributed/rpc/py_rref.cpp",
"torch/csrc/distributed/rpc/python_functions.cpp",
"torch/csrc/distributed/rpc/python_rpc_handler.cpp",
"torch/csrc/distributed/rpc/request_callback_impl.cpp",
"torch/csrc/distributed/rpc/rpc_agent.cpp",
"torch/csrc/distributed/rpc/rref.cpp",
"torch/csrc/distributed/rpc/rref_context.cpp",
"torch/csrc/distributed/rpc/types.cpp",
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@ -229,13 +229,13 @@ if (USE_DISTRIBUTED)
${TORCH_SRC_DIR}/csrc/distributed/c10d/reducer.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/py_rref.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)

View File

@ -7,7 +7,6 @@ namespace rpc {
const Message& FutureMessage::wait() {
std::unique_lock<std::mutex> lock(mutex_);
finished_cv_.wait(lock, [this] { return completed_.load(); });
return message_;
}

View File

@ -2,6 +2,7 @@
#include <torch/csrc/distributed/rpc/future_message.h>
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref.h>
@ -29,9 +30,9 @@ PyObject* rpc_init(PyObject* /* unused */) {
auto module = py::handle(dist_module).cast<py::module>();
auto workerId = shared_ptr_class_<WorkerId>(module, "WorkerId")
.def_readonly("name", &WorkerId::name_)
.def_readonly("id", &WorkerId::id_);
auto workerInfo = shared_ptr_class_<WorkerInfo>(module, "WorkerInfo")
.def_readonly("name", &WorkerInfo::name_)
.def_readonly("id", &WorkerInfo::id_);
auto rpcAgent =
shared_ptr_class_<RpcAgent>(module, "RpcAgent")
@ -42,13 +43,33 @@ PyObject* rpc_init(PyObject* /* unused */) {
&RpcAgent::sync,
py::call_guard<py::gil_scoped_release>());
auto rref =
shared_ptr_class_<RRef>(module, "RRef")
.def("owner", &RRef::owner, py::call_guard<py::gil_scoped_release>())
auto pyRRef =
shared_ptr_class_<PyRRef>(module, "RRef")
.def(
// not releasing GIL here to avoid context switch on getters
"is_owner",
&PyRRef::isOwner)
.def(
// not releasing GIL here to avoid context switch on getters
"owner",
&PyRRef::owner)
.def(
"to_here",
[&](RRef& rref) { return torch::jit::toPyObject(rref.toHere()); },
py::call_guard<py::gil_scoped_release>());
&PyRRef::toHere,
py::call_guard<py::gil_scoped_release>())
.def(
"local_value",
&PyRRef::localValue,
py::call_guard<py::gil_scoped_release>())
.def(py::pickle(
[](const PyRRef& self) {
// __getstate__
return self.pickle();
},
[](py::tuple t) { // NOLINT
// __setstate__
return PyRRef::unpickle(t);
}));
auto futureMessage =
shared_ptr_class_<FutureMessage>(module, "FutureMessage")
@ -64,14 +85,14 @@ PyObject* rpc_init(PyObject* /* unused */) {
py::arg("process_group"),
py::arg("num_send_recv_threads") = 4)
.def(
"get_worker_id",
(const WorkerId& (ProcessGroupAgent::*)(void)const) &
RpcAgent::getWorkerId,
"get_worker_info",
(const WorkerInfo& (ProcessGroupAgent::*)(void)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_id",
(const WorkerId& (ProcessGroupAgent::*)(const std::string&)const) &
ProcessGroupAgent::getWorkerId,
"get_worker_info",
(const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) &
ProcessGroupAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"join",
@ -82,14 +103,18 @@ PyObject* rpc_init(PyObject* /* unused */) {
&ProcessGroupAgent::sync,
py::call_guard<py::gil_scoped_release>());
module.def("init_rref_context", [](std::shared_ptr<RpcAgent> agent) {
module.def("_init_rref_context", [](std::shared_ptr<RpcAgent> agent) {
RRefContext::initInstance(std::move(agent));
});
module.def("_destroy_rref_context", []() {
RRefContext::getInstance()->destroyInstance();
});
module.def(
"invoke_rpc_builtin",
[](RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
@ -99,8 +124,8 @@ PyObject* rpc_init(PyObject* /* unused */) {
module.def(
"invoke_rpc_python_udf",
[](RpcAgent& agent,
const WorkerId& dst,
const std::string& pickledPythonUDF,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
return pyRpcPythonUdf(agent, dst, pickledPythonUDF, tensors);
});
@ -108,13 +133,22 @@ PyObject* rpc_init(PyObject* /* unused */) {
module.def(
"invoke_remote_builtin",
[](RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
return pyRemoteBuiltin(agent, dst, opName, args, kwargs);
});
module.def(
"invoke_remote_python_udf",
[](RpcAgent& agent,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
return pyRemotePythonUdf(agent, dst, pickledPythonUDF, tensors);
});
Py_RETURN_TRUE;
}

View File

@ -65,24 +65,28 @@ const MessageType& Message::type() const {
}
bool Message::isRequest() const {
return MessageType::SCRIPT_CALL == type_ ||
MessageType::PYTHON_CALL == type_ || MessageType::REMOTE_CALL == type_ ||
MessageType::MESSAGE_WITH_AUTOGRAD_REQ == type_ ||
MessageType::RREF_FETCH_CALL == type_ ||
MessageType::RREF_USER_CREATE == type_ ||
MessageType::RREF_USER_DELETE == type_;
}
bool Message::requiresResponse() const {
return MessageType::SCRIPT_CALL == type_ ||
MessageType::PYTHON_CALL == type_ ||
MessageType::MESSAGE_WITH_AUTOGRAD_REQ == type_ ||
MessageType::RREF_FETCH_CALL == type_;
return MessageType::SCRIPT_CALL == type_ || // dist.rpc on builtin ops
MessageType::PYTHON_CALL == type_ || // dist.rpc on Python UDFs
MessageType::SCRIPT_REMOTE_CALL == type_ || // dist.remote on builtin ops
MessageType::PYTHON_REMOTE_CALL == type_ || // dist.remote on Python UDFs
// RRef related internal messages
MessageType::SCRIPT_RREF_FETCH_CALL == type_ ||
MessageType::PYTHON_RREF_FETCH_CALL == type_ ||
MessageType::RREF_USER_DELETE == type_ ||
MessageType::RREF_CHILD_ACCEPT == type_ ||
MessageType::RREF_FORK_REQUEST == type_ ||
// Autograd message
MessageType::MESSAGE_WITH_AUTOGRAD_REQ == type_;
}
bool Message::isResponse() const {
return MessageType::SCRIPT_RET == type_ || MessageType::PYTHON_RET == type_ ||
MessageType::RREF_FETCH_RET == type_ ||
return MessageType::SCRIPT_RET == type_ || // ret of dist.rpc on builtin ops
MessageType::PYTHON_RET == type_ || // ret of dist.rpc on Python UDFs
MessageType::REMOTE_RET == type_ || // ret of dist.remote
MessageType::RREF_FETCH_RET == type_ || // ret on RRef::toHere()
MessageType::EXCEPTION == type_ || // propagate back exceptions
MessageType::RREF_ACK == type_ || // ret of other types
// Autograd response
MessageType::MESSAGE_WITH_AUTOGRAD_RESP == type_;
}

View File

@ -8,20 +8,36 @@ namespace distributed {
namespace rpc {
enum MessageType {
// messages for dist.rpc on builtin operators
SCRIPT_CALL = 0,
SCRIPT_RET,
PYTHON_CALL,
PYTHON_RET,
REMOTE_CALL,
RREF_FETCH_CALL,
RREF_FETCH_RET,
RREF_USER_CREATE,
RREF_USER_DELETE,
MESSAGE_WITH_AUTOGRAD_REQ,
MESSAGE_WITH_AUTOGRAD_RESP,
SHUTDOWN,
EXCEPTION,
UNKNOWN
SCRIPT_RET = 1,
// messages for dist.rpc on Python UDF
PYTHON_CALL = 2,
PYTHON_RET = 3,
// messages for dist.remote on builtin operators and Python UDF
SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator
PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF
REMOTE_RET = 6, // A remote call on a Python UDF
// RRef related internal messages
SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef<IValue> fetches value from owner
PYTHON_RREF_FETCH_CALL = 8, // A UserRRef<py::object> fetches value from owner
RREF_FETCH_RET = 9, // An OwnerRRef sends value to user
RREF_USER_DELETE = 10, // A UserRRef tells the owner to deref
RREF_FORK_REQUEST = 11, // A child UserRRef tells the owner about itself
RREF_CHILD_ACCEPT = 12, // A child UserRRef tells parent that owner knows it
RREF_ACK = 13, // ACK to internal RRef messages
// Messages with autograd info
MESSAGE_WITH_AUTOGRAD_REQ = 14,
MESSAGE_WITH_AUTOGRAD_RESP = 15,
// Other internal message types
SHUTDOWN = 16,
EXCEPTION = 17,
UNKNOWN = 18
};
// A message to be sent/received by an RpcAgent.
@ -73,7 +89,6 @@ class TORCH_API Message final {
const MessageType& type() const;
bool isRequest() const;
bool requiresResponse() const;
bool isResponse() const;
bool isShutdown() const;

View File

@ -74,18 +74,18 @@ std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
//////////////////////// ProcessGroupAgent /////////////////////////////////
void ProcessGroupAgent::collectNames() {
const std::string& workerName = workerId_.name_;
const std::string& workerName = workerInfo_.name_;
const auto worldSize = pg_->getSize();
// use c10d allgather to collect names
torch::Tensor nameTensor =
torch::zeros({WorkerId::MAX_NAME_LEN}, torch::kChar);
torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar);
memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length());
std::vector<torch::Tensor> inputName = {nameTensor};
std::vector<std::vector<torch::Tensor>> outputNames(1);
for (int i = 0; i < worldSize; ++i) {
outputNames[0].emplace_back(
torch::empty({WorkerId::MAX_NAME_LEN}, {torch::kChar}));
torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar}));
}
pg_->allgather(outputNames, inputName)->wait();
@ -109,7 +109,7 @@ ProcessGroupAgent::ProcessGroupAgent(
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads)
: RpcAgent(
WorkerId(std::move(workerName), pg->getRank()),
WorkerInfo(std::move(workerName), pg->getRank()),
c10::guts::make_unique<RequestCallbackImpl>()),
pg_(std::move(pg)),
sendCounts_(pg_->getSize()),
@ -123,12 +123,12 @@ ProcessGroupAgent::ProcessGroupAgent(
"ProcessGroupAgent requires world_size to "
"be at least 2, but got ",
nameMap_.size());
auto workerRankIter = nameMap_.find(workerId_.name_);
auto workerRankIter = nameMap_.find(workerInfo_.name_);
TORCH_CHECK(
workerRankIter != nameMap_.end(),
"Failed to resolve worker "
"name ",
workerId_.name_,
workerInfo_.name_,
" to a ProcessGroup rank.");
TORCH_CHECK(
pg_->getRank() == workerRankIter->second,
@ -143,9 +143,9 @@ ProcessGroupAgent::ProcessGroupAgent(
tmpWorkerIds[entry.second] = entry.first;
}
workerIds_.reserve(pg_->getSize());
allWorkerInfo_.reserve(pg_->getSize());
for (int rank = 0; rank < (int)tmpWorkerIds.size(); ++rank) {
workerIds_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
}
// construct PythonRpcHandler singleton here
@ -153,17 +153,17 @@ ProcessGroupAgent::ProcessGroupAgent(
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
}
const WorkerId& ProcessGroupAgent::getWorkerId(
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(
const std::string& workerName) const {
const auto idIter = nameMap_.find(workerName);
TORCH_CHECK(
idIter != nameMap_.end(), "Unknown destination worker ", workerName);
return workerIds_[idIter->second];
return allWorkerInfo_[idIter->second];
}
const WorkerId& ProcessGroupAgent::getWorkerId(worker_id_t id) const {
return workerIds_[id];
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const {
return allWorkerInfo_[id];
}
void ProcessGroupAgent::join() {
@ -174,9 +174,13 @@ void ProcessGroupAgent::join() {
// 2. A GLOO process cannot send message to itself. (there is an ongoing
// effort to fix this problem).
sync();
std::unique_lock<std::mutex> lock(futureMutex_);
futureCV_.wait(lock, [this] { return futures_.empty(); });
lock.unlock();
pg_->barrier()->wait();
int dst = (pg_->getRank() + 1) % pg_->getSize();
enqueueSend(
SendWork(workerIds_[dst], Message({}, {}, MessageType::SHUTDOWN)));
SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN)));
threadPool_.waitWorkComplete();
listenerThread_.join();
}
@ -198,7 +202,6 @@ bool ProcessGroupAgent::hasPendingMessage() {
std::vector<torch::Tensor> inputSnapshot = {
torch::from_blob(snapshot.data(), {2, worldSize}, {torch::kInt64})};
// allgather both send and recv messages in one shot
std::vector<std::vector<torch::Tensor>> outputSnapshots(1);
@ -245,7 +248,7 @@ void ProcessGroupAgent::sync() {
}
std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
const WorkerId& to,
const WorkerInfo& to,
Message&& message) {
TORCH_CHECK(
to.id_ != (worker_id_t)pg_->getRank(),
@ -259,7 +262,7 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
auto requestId = nextId();
auto future = std::make_shared<FutureMessage>();
if (message.requiresResponse()) {
if (message.isRequest()) {
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[requestId] = future;
@ -271,14 +274,14 @@ std::shared_ptr<FutureMessage> ProcessGroupAgent::send(
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
// longer be alive when the ``SendWork`` is executed. For example, the
// application could query the ``WorkerId`` using name through the
// ``RpcAgent::getWorkerId`` API, and pass the ``WorkerId`` back here, so we
// have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerId``
// application could query the ``WorkerInfo`` using name through the
// ``RpcAgent::getWorkerInfo`` API, and pass the ``WorkerInfo`` back here, so
// we have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerInfo``
// reference on Python side could die before ``SendWork`` uses it, and Pybind
// will not keep the Python reference alive even if it originally comes from
// the C++ land. Hence, we have to explicitly use the ``workerId`` in the C++
// land.
enqueueSend(SendWork(workerIds_[to.id_], std::move(message)));
// the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
// C++ land.
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
return future;
}
@ -340,18 +343,23 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
(char*)payload.storage().data<signed char>(), payload.numel()));
Message message = deserialize(work.type_, ss);
if (message.requiresResponse()) {
auto response = cb_->operator()(message);
send(work.from_, std::move(response));
} else if (message.isRequest()) {
cb_->operator()(message);
if (message.isRequest()) {
send(work.from_, cb_->operator()(message));
} else if (message.isResponse()) {
auto id = message.id();
std::shared_ptr<FutureMessage> fm = nullptr;
{
std::lock_guard<std::mutex> lock{futureMutex_};
fm = futures_[id];
}
// Not holding lock on markCompleted as this could run callbacks that
// call agent_->send
fm->markCompleted(std::move(message));
{
std::lock_guard<std::mutex> lock{futureMutex_};
futures_[id]->markCompleted(std::move(message));
futures_.erase(id);
}
futureCV_.notify_all();
} else {
// TODO: pass the error back to the caller instead of crashing here.
TORCH_INTERNAL_ASSERT(
@ -378,7 +386,7 @@ void ProcessGroupAgent::listenLoop() {
// FIXME: This LOG also prints warnings no InitGoogleLogging() was invoked
// before logging, but it is not appropriate to call InitGoogleLogging()
// here either.
LOG(INFO) << "Shutting down ProcessGroupAgent " << workerId_.name_
LOG(INFO) << "Shutting down ProcessGroupAgent " << workerInfo_.name_
<< std::endl;
return;
}
@ -386,7 +394,7 @@ void ProcessGroupAgent::listenLoop() {
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
pg_->recv(tensors, srcRank, pg_->getRank())->wait();
enqueueRecv(RecvWork(workerIds_[srcRank], type, std::move(tensors[0])));
enqueueRecv(RecvWork(allWorkerInfo_[srcRank], type, std::move(tensors[0])));
}
}

View File

@ -15,20 +15,20 @@ namespace rpc {
// SendWork and RecvWork will be put into a task queue, and later picked up by
// worker threads from the same ThreadPool.
struct SendWork {
SendWork(const WorkerId& to, Message&& message)
SendWork(const WorkerInfo& to, Message&& message)
: to_(to), message_(message) {}
const WorkerId& to_;
const WorkerInfo& to_;
Message message_;
};
// SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
// to allow us to run serialization/deserialization in the worker threads.
struct RecvWork {
RecvWork(const WorkerId& from, MessageType type, torch::Tensor&& payload)
RecvWork(const WorkerInfo& from, MessageType type, torch::Tensor&& payload)
: from_(from), type_(type), payload_(payload) {}
const WorkerId& from_;
const WorkerInfo& from_;
const MessageType type_;
torch::Tensor payload_;
};
@ -40,9 +40,9 @@ class ProcessGroupAgent : public RpcAgent {
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads = 4);
const WorkerId& getWorkerId(const std::string& workerName) const override;
const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
const WorkerId& getWorkerId(worker_id_t id) const override;
const WorkerInfo& getWorkerInfo(worker_id_t id) const override;
void join() override;
@ -52,7 +52,7 @@ class ProcessGroupAgent : public RpcAgent {
// This method wraps the destination information and the message into a
// SendWork object, and put the SendWork into a queue. Another thread will
// consume SendWork from the queue and send it out.
std::shared_ptr<FutureMessage> send(const WorkerId& to, Message&& message)
std::shared_ptr<FutureMessage> send(const WorkerInfo& to, Message&& message)
override;
private:
@ -99,13 +99,13 @@ class ProcessGroupAgent : public RpcAgent {
bool hasPendingMessage();
int64_t nextId() {
return nextId_++;
return ++nextId_;
}
std::shared_ptr<c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, int> nameMap_;
std::vector<WorkerId> workerIds_;
std::vector<WorkerInfo> allWorkerInfo_;
// record the number of messages sent to and received from each peer. The recv
// counter is only marked after the message is processed. Join uses allgather
// to collect all counts from all peers, uses these counters to detect global
@ -129,7 +129,8 @@ class ProcessGroupAgent : public RpcAgent {
// This is just a temporary solution for (2).
ThreadPool threadPool_;
std::unordered_map<int64_t, std::shared_ptr<FutureMessage>> futures_;
std::mutex futureMutex_;
mutable std::mutex futureMutex_;
mutable std::condition_variable futureCV_;
};
} // namespace rpc

View File

@ -0,0 +1,135 @@
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
// Constants below are used in PyRRef pickling and unpickling. PyRRef is
// converted into a py::tuple in pickling, and reconstructed from the py::tuple
// in unpickling.
constexpr int RFD_IDX = 0; // index of RRefForkData
constexpr int TYPE_IDX = 1; // index of type (py::object or IValue)
// number of data fields in the py::tuple.
// NB: if more fields are added, make sure this field is also bumped
constexpr int RREF_TUPLE_SIZE = 2;
} // namespace
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}
bool PyRRef::isOwner() const {
return rref_->isOwner();
}
worker_id_t PyRRef::owner() const {
return rref_->owner();
}
py::object PyRRef::toHere() {
if (rref_->isOwner()) {
if (rref_->isPyObj()) {
const py::object& value =
std::static_pointer_cast<OwnerRRef<py::object>>(rref_)->getValue();
{
// acquiring GIL as the return statement construct a new py::object from
// a const reference.
AutoGIL ag;
return value;
}
} else {
IValue value =
std::static_pointer_cast<OwnerRRef<IValue>>(rref_)->getValue();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.
AutoGIL ag;
return torch::jit::toPyObject(std::move(value));
}
}
} else {
if (rref_->isPyObj()) {
// UserRRef<py::object>::toHere() calls python_rpc_handler which acquires
// GIL.
return std::static_pointer_cast<UserRRef<py::object>>(rref_)->toHere();
} else {
IValue value =
std::static_pointer_cast<UserRRef<IValue>>(rref_)->toHere();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.
AutoGIL ag;
return torch::jit::toPyObject(std::move(value));
}
}
}
}
py::object PyRRef::localValue() {
TORCH_CHECK(
rref_->isOwner(),
"Cannot call localValue() on a non-local reference. Call it on ",
RRefContext::getInstance()->getWorkerName());
if (rref_->isPyObj()) {
const py::object& value =
std::dynamic_pointer_cast<OwnerRRef<py::object>>(rref_)->getValue();
{
// acquiring GIL as the return statement construct a new py::object from
// a const reference.
AutoGIL ag;
return value;
}
} else {
auto value =
std::dynamic_pointer_cast<OwnerRRef<IValue>>(rref_)->getValue();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object without
// grabbing the GIL.
AutoGIL ag;
return torch::jit::toPyObject(std::move(value));
}
}
}
py::tuple PyRRef::pickle() const {
auto& ctx = RRefContext::getInstance();
// TODO: use a dispatch table to pickle/unpickle an RRef, and only only
// install the dispatch table only when there are indeed RPC activities. As
// a counter example, checkpointing a model with RRefs should not trigger
// forks to be added as a fork or a child.
auto rfd = ctx->prepareChildFork(rref_);
return py::make_tuple(rfd.toPyTuple(), rref_->isPyObj());
}
PyRRef PyRRef::unpickle(const py::tuple& t) {
TORCH_INTERNAL_ASSERT(
t.size() == RREF_TUPLE_SIZE, "Pickled RRef must contain 2 numbers.");
auto& ctx = RRefContext::getInstance();
auto rfd = RRefForkData::fromPyTuple(t[RFD_IDX].cast<py::tuple>());
std::shared_ptr<RRef> rref = nullptr;
bool isPyObj = t[TYPE_IDX].cast<bool>();
if (isPyObj) {
rref = ctx->getOrCreateRRef<py::object>(rfd);
} else {
rref = ctx->getOrCreateRRef<IValue>(rfd);
}
ctx->notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
return PyRRef(std::move(rref));
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,30 @@
#pragma once
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
// Python wrapper of an RRef shared_ptr that supports Python
// pickle and unpickle.
class PyRRef {
public:
explicit PyRRef(std::shared_ptr<RRef> rref);
bool isOwner() const;
worker_id_t owner() const;
py::object toHere();
py::object localValue();
py::tuple pickle() const;
static PyRRef unpickle(const py::tuple& t);
private:
std::shared_ptr<RRef> rref_;
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -6,6 +6,18 @@
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
@ -47,8 +59,13 @@ std::shared_ptr<Operator> matchBuiltinOp(
", kwargs: ",
kwargs,
") to a builtin operator");
}
// builtin operators.
void finishAcceptUserRRef(const Message& message) {
RRefContext::handleException(message);
auto rr = RemoteRet::fromMessage(message);
auto& ctx = RRefContext::getInstance();
ctx->delPendingUser(rr->forkId());
}
} // namespace
@ -71,6 +88,7 @@ py::object toPyObjInternal(RpcCommandBase& rpc, MessageType messageType) {
case MessageType::PYTHON_RET: {
// TODO: Try to avoid a copy here.
auto& resp = static_cast<PythonUDFResp&>(rpc);
return PythonRpcHandler::getInstance().loadPythonUDFResult(
resp.pickledPayload(), resp.tensors());
}
@ -97,7 +115,7 @@ py::object toPyObj(const Message& message) {
std::shared_ptr<FutureMessage> pyRpcBuiltin(
RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
@ -127,9 +145,9 @@ std::shared_ptr<FutureMessage> pyRpcBuiltin(
}
}
std::shared_ptr<RRef> pyRemoteBuiltin(
PyRRef pyRemoteBuiltin(
RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs) {
@ -137,22 +155,26 @@ std::shared_ptr<RRef> pyRemoteBuiltin(
auto op = matchBuiltinOp(opName, args, kwargs, stack);
auto& ctx = RRefContext::getInstance();
auto userRRef = ctx->createUserRRef(dst.id_);
agent.send(
// TODO: support creating RRefs on a local object.
TORCH_INTERNAL_ASSERT(
ctx->getWorkerId() != dst.id_,
"Does not support creating RRef on self yet.");
auto userRRef = ctx->createUserRRef<IValue>(dst.id_);
auto fm = agent.send(
dst,
ScriptRemoteCall(
op,
std::move(stack),
userRRef->id().toIValue(),
userRRef->forkId().toIValue())
op, std::move(stack), userRRef->rrefId(), userRRef->forkId())
.toMessage());
return userRRef;
ctx->addPendingUser(userRRef->forkId(), userRRef);
fm->addCallback(finishAcceptUserRRef);
return PyRRef(userRRef);
}
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
RpcAgent& agent,
const WorkerId& dst,
const std::string& pickledPythonUDF,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
return agent.send(
dst,
@ -162,6 +184,30 @@ std::shared_ptr<FutureMessage> pyRpcPythonUdf(
.toMessage());
}
PyRRef pyRemotePythonUdf(
RpcAgent& agent,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors) {
auto& ctx = RRefContext::getInstance();
// TODO: support creating RRefs on a local object.
TORCH_INTERNAL_ASSERT(
ctx->getWorkerId() != dst.id_,
"Does not support creating RRef on self yet.");
auto userRRef = ctx->createUserRRef<py::object>(dst.id_);
auto fm = agent.send(
dst,
PythonRemoteCall(
SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors)),
userRRef->rrefId().toIValue(),
userRRef->forkId().toIValue())
.toMessage());
ctx->addPendingUser(userRRef->forkId(), userRRef);
fm->addCallback(finishAcceptUserRRef);
return PyRRef(userRRef);
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -1,15 +1,8 @@
#pragma once
#include <torch/csrc/distributed/rpc/future_message.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
@ -20,24 +13,30 @@ py::object toPyObj(const Message& message);
std::shared_ptr<FutureMessage> pyRpcBuiltin(
RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs);
std::shared_ptr<FutureMessage> pyRpcPythonUdf(
RpcAgent& agent,
const WorkerId& dst,
const std::string& pickledPythonUDF,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors);
std::shared_ptr<RRef> pyRemoteBuiltin(
PyRRef pyRemoteBuiltin(
RpcAgent& agent,
const WorkerId& dst,
const WorkerInfo& dst,
const std::string& opName,
const py::args& args,
const py::kwargs& kwargs);
PyRRef pyRemotePythonUdf(
RpcAgent& agent,
const WorkerInfo& dst,
std::string& pickledPythonUDF,
std::vector<torch::Tensor>& tensors);
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,53 @@
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/jit/pickle.h>
namespace torch {
namespace distributed {
namespace rpc {
PythonRemoteCall::PythonRemoteCall(
SerializedPyObj&& serializedPyObj,
at::IValue retRRefId,
at::IValue retForkId)
: serializedPyObj_(std::move(serializedPyObj)),
retRRefId_(std::move(retRRefId)),
retForkId_(std::move(retForkId)) {}
Message PythonRemoteCall::toMessage() && {
std::vector<IValue> ivalues = serializedPyObj_.toIValues();
ivalues.emplace_back(retRRefId_);
ivalues.emplace_back(retForkId_);
std::vector<torch::Tensor> tensor_table;
auto payload =
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
return Message(
std::move(payload),
std::move(tensor_table),
MessageType::PYTHON_REMOTE_CALL);
}
std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto values = value.toTuple()->elements();
// remove the last element from values and convert it back to an RRef
auto retForkId = std::move(values.back());
values.pop_back();
auto retRRefId = std::move(values.back());
values.pop_back();
auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
return c10::guts::make_unique<PythonRemoteCall>(
std::move(serializedPyObj), std::move(retRRefId), std::move(retForkId));
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,43 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/pickler.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
class TORCH_API PythonRemoteCall : public RpcCommandBase {
public:
PythonRemoteCall(
SerializedPyObj&& serializedPyObj,
at::IValue retRRefId,
at::IValue retForkId);
inline const SerializedPyObj& serializedPyObj() const {
return serializedPyObj_;
}
inline const at::IValue& retRRefId() const {
return retRRefId_;
}
inline const at::IValue& retForkId() const {
return retForkId_;
}
Message toMessage() && override;
static std::unique_ptr<PythonRemoteCall> fromMessage(const Message& message);
private:
const SerializedPyObj serializedPyObj_;
const at::IValue retRRefId_;
const at::IValue retForkId_;
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -10,6 +10,7 @@ PythonRpcHandler::PythonRpcHandler() {
py::module::import("torch.distributed.internal_rpc_utils");
runUDFFunction_ = module.attr("run_python_udf_internal");
loadResultFunction_ = module.attr("load_python_udf_result_internal");
serializeFunction_ = module.attr("serialize");
}
PythonRpcHandler& PythonRpcHandler::getInstance() {
@ -24,7 +25,8 @@ std::vector<char> PythonRpcHandler::generatePythonUDFResult(
AutoGIL ag;
auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size());
TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr");
py::tuple pres = runUDFFunction_(pargs, requestTensorTable);
py::tuple pres =
serializeFunction_(runUDFFunction_(pargs, requestTensorTable));
const auto& presStr = pres[0].cast<std::string>();
responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
std::vector<char> payload(presStr.begin(), presStr.end());
@ -40,6 +42,26 @@ py::object PythonRpcHandler::loadPythonUDFResult(
return loadResultFunction_(pargs, tensorTable);
}
py::object PythonRpcHandler::runPythonUDF(
const SerializedPyObj& serializedObj) {
AutoGIL ag;
return runUDFFunction_(
py::bytes(serializedObj.payload_), serializedObj.tensors_);
}
SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
AutoGIL ag;
py::tuple t = serializeFunction_(obj);
return SerializedPyObj(
t[0].cast<std::string>(), t[1].cast<std::vector<torch::Tensor>>());
}
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
AutoGIL ag;
return loadResultFunction_(
py::bytes(serializedObj.payload_), serializedObj.tensors_);
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
@ -26,6 +27,12 @@ class PYBIND11_EXPORT PythonRpcHandler {
py::object loadPythonUDFResult(
const std::vector<char>& pickledPayload,
const std::vector<torch::Tensor>& tensorTable);
// Run a pickled Python UDF and return the result py::object
py::object runPythonUDF(const SerializedPyObj& serializedObj);
// Serialized a py::object into a string
SerializedPyObj serialize(const py::object& obj);
// Deserialize a string into a py::object
py::object deserialize(const SerializedPyObj& serializedObj);
private:
PythonRpcHandler();
@ -38,6 +45,7 @@ class PYBIND11_EXPORT PythonRpcHandler {
py::object runUDFFunction_;
py::object loadResultFunction_;
py::object serializeFunction_;
};
} // namespace rpc

View File

@ -4,16 +4,17 @@
#include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/future_message.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/script_rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
@ -56,15 +57,11 @@ std::unique_ptr<RpcCommandBase> RequestCallbackImpl::processRpc(
return c10::guts::make_unique<PythonUDFResp>(
std::move(payload), std::move(responseTensorTable));
}
case MessageType::REMOTE_CALL: {
case MessageType::SCRIPT_REMOTE_CALL: {
auto& src = static_cast<ScriptRemoteCall&>(rpc);
auto rrefId = RRefId::fromIValue(src.retRRefId());
auto forkId = ForkId::fromIValue(src.retForkId());
TORCH_CHECK(rrefId != forkId, "Does not support remote call to self.");
auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx->getOrCreateOwnerRRef<IValue>(rrefId);
auto ownerRRef = ctx->getOrCreateOwnerRRef<IValue>(src.retRRefId());
// TODO: make this asynchronous
// src is only alive within this block, use reference to avoid copy
@ -78,25 +75,60 @@ std::unique_ptr<RpcCommandBase> RequestCallbackImpl::processRpc(
stack.size());
ownerRRef->setValue(std::move(stack.front()));
return nullptr;
ctx->addForkOfOwner(src.retRRefId(), src.retForkId());
return c10::guts::make_unique<RemoteRet>(
src.retRRefId(), src.retForkId());
}
case MessageType::RREF_FETCH_CALL: {
case MessageType::PYTHON_REMOTE_CALL: {
auto& prc = static_cast<PythonRemoteCall&>(rpc);
auto rrefId = RRefId::fromIValue(prc.retRRefId());
auto forkId = ForkId::fromIValue(prc.retForkId());
auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx->getOrCreateOwnerRRef<py::object>(rrefId);
ownerRRef->setValue(
PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()));
ctx->addForkOfOwner(rrefId, forkId);
return c10::guts::make_unique<RemoteRet>(rrefId, forkId);
}
case MessageType::SCRIPT_RREF_FETCH_CALL: {
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
// TODO: make this asynchronous
std::shared_ptr<OwnerRRef<IValue>> rref =
RRefContext::getInstance()->getOrCreateOwnerRRef<IValue>(
RRefId::fromIValue(srf.value()));
return c10::guts::make_unique<ScriptRRefFetchRet>(rref->getValue());
ctx->getOrCreateOwnerRRef<IValue>(srf.rrefId());
return c10::guts::make_unique<RRefFetchRet>(
RRefFetchRet({rref->getValue()}));
}
case MessageType::RREF_USER_CREATE: {
auto& sra = static_cast<ScriptRRefCreate&>(rpc);
RRefContext::getInstance()->addFork(sra.valueRef());
return nullptr;
case MessageType::PYTHON_RREF_FETCH_CALL: {
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
// TODO: make this asynchronous
std::shared_ptr<OwnerRRef<py::object>> rref =
ctx->getOrCreateOwnerRRef<py::object>(prf.rrefId());
SerializedPyObj result =
PythonRpcHandler::getInstance().serialize(rref->getValue());
return c10::guts::make_unique<RRefFetchRet>(
RRefFetchRet(result.toIValues()));
}
case MessageType::RREF_USER_DELETE: {
auto& srd = static_cast<ScriptRRefDelete&>(rpc);
RRefContext::getInstance()->delFork(srd.valueRef());
return nullptr;
auto& rud = static_cast<RRefUserDelete&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx->delForkOfOwner(rud.rrefId(), rud.forkId());
return c10::guts::make_unique<RRefAck>();
}
case MessageType::RREF_CHILD_ACCEPT: {
auto& rca = static_cast<RRefChildAccept&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx->delPendingChild(rca.forkId());
return c10::guts::make_unique<RRefAck>();
}
case MessageType::RREF_FORK_REQUEST: {
auto& rfr = static_cast<RRefForkRequest&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx->addForkOfOwner(rfr.rrefId(), rfr.forkId());
return c10::guts::make_unique<RRefAck>();
}
case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: {
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

View File

@ -4,15 +4,15 @@ namespace torch {
namespace distributed {
namespace rpc {
constexpr size_t WorkerId::MAX_NAME_LEN;
constexpr size_t WorkerInfo::MAX_NAME_LEN;
RpcAgent::RpcAgent(WorkerId workerId, std::unique_ptr<RequestCallback> cb)
: workerId_(std::move(workerId)), cb_(std::move(cb)) {}
RpcAgent::RpcAgent(WorkerInfo workerId, std::unique_ptr<RequestCallback> cb)
: workerInfo_(std::move(workerId)), cb_(std::move(cb)) {}
RpcAgent::~RpcAgent() = default;
const WorkerId& RpcAgent::getWorkerId() const {
return workerId_;
const WorkerInfo& RpcAgent::getWorkerInfo() const {
return workerInfo_;
}
} // namespace rpc

View File

@ -12,9 +12,9 @@ namespace distributed {
namespace rpc {
// A globally unique ID to identify an RpcAgent
struct WorkerId {
WorkerId(std::string name, int id)
: WorkerId(std::move(name), (worker_id_t)id) {
struct WorkerInfo {
WorkerInfo(std::string name, int id)
: WorkerInfo(std::move(name), (worker_id_t)id) {
TORCH_CHECK(
id <= std::numeric_limits<worker_id_t>::max(),
"RPC worker id ",
@ -22,7 +22,8 @@ struct WorkerId {
" out of bound of int16_t.");
}
WorkerId(std::string name, worker_id_t id) : name_(std::move(name)), id_(id) {
WorkerInfo(std::string name, worker_id_t id)
: name_(std::move(name)), id_(id) {
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
bool validChar =
std::find_if(name_.begin(), name_.end(), [](char c) {
@ -51,9 +52,9 @@ struct WorkerId {
// construction.
class RpcAgent {
public:
// `WorkerId` is the globally unique identifier for this RpcAgent instance. It
// contains a ``name_`` field and an ``id_`` field. ``name_`` is the globally
// unique name for this ``RpcAgent``. It is up to the ``RpcAgent``
// `WorkerInfo` is the globally unique identifier for this RpcAgent instance.
// It contains a ``name_`` field and an ``id_`` field. ``name_`` is the
// globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent``
// implementation to determine how to resolve names. ``id_`` is the globally
// unique ID for this ``RpcAgent``. This should be determined by the
// ``RpcAgent`` implementation.
@ -61,7 +62,7 @@ class RpcAgent {
// ``RpcAgent`` base class makes no assumption on the thread-safeness of the
// ``RequestCallback``. ``RpcAgent`` implementations need to make sure that
// its threading model conform to ``RequestCallback``'s requirement.
RpcAgent(WorkerId id, std::unique_ptr<RequestCallback> cb);
RpcAgent(WorkerInfo id, std::unique_ptr<RequestCallback> cb);
virtual ~RpcAgent();
@ -73,19 +74,20 @@ class RpcAgent {
// when the response arrives. For other message types, the Future should be
// ignored by the caller.
virtual std::shared_ptr<FutureMessage> send(
const WorkerId& to,
const WorkerInfo& to,
Message&& message) = 0;
// Return a reference to the ``WorkerId`` of this RpcAgent.
// Return a reference to the ``WorkerInfo`` of this RpcAgent.
// NB: not using ``c10::optional<const std::string&>`` here because we might
// need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
// implementations to depend on libtorch.
const WorkerId& getWorkerId() const;
const WorkerInfo& getWorkerInfo() const;
// Return a reference to the ``WorkerId`` of the given ``workerName``.
virtual const WorkerId& getWorkerId(const std::string& workerName) const = 0;
// Return a reference to the ``WorkerInfo`` of the given ``workerName``.
virtual const WorkerInfo& getWorkerInfo(
const std::string& workerName) const = 0;
virtual const WorkerId& getWorkerId(worker_id_t id) const = 0;
virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;
// Call sync and join all internal threads. This method should be called
// before every RPC process exits.
@ -96,7 +98,7 @@ class RpcAgent {
virtual void sync() = 0;
protected:
const WorkerId workerId_;
const WorkerInfo workerInfo_;
const std::string workerName_;
const std::unique_ptr<RequestCallback> cb_;
};

View File

@ -1,11 +1,27 @@
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/script_rref_proto.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple
} // namespace
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
////////////////////////// RRefForkData /////////////////////////////////
@ -13,27 +29,47 @@ std::atomic<local_id_t> RRefContext::nextLocalId_{0};
RRefForkData::RRefForkData(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId)
: ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId) {}
const ForkId& forkId,
worker_id_t parent)
: ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {}
at::IValue RRefForkData::toIValue() const {
std::vector<at::IValue> ivalues = {
(int64_t)ownerId_, rrefId_.toIValue(), forkId_.toIValue()};
py::tuple RRefForkData::toPyTuple() const {
return py::make_tuple(
ownerId_,
rrefId_.createdOn_,
rrefId_.localId_,
forkId_.createdOn_,
forkId_.localId_,
parent_);
}
return c10::ivalue::Tuple::create(std::move(ivalues));
RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
TORCH_INTERNAL_ASSERT(
t.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain 6 numbers.");
worker_id_t ownerId = t[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(
t[RREFID_ON_IDX].cast<worker_id_t>(),
t[RREFID_ID_IDX].cast<local_id_t>());
const RRefId& forkId = RRefId(
t[FORKID_ON_IDX].cast<worker_id_t>(),
t[FORKID_ID_IDX].cast<local_id_t>());
worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
return RRefForkData(ownerId, rrefId, forkId, parent);
}
RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) {
auto ivalues = ivalue.toTuple()->elements();
TORCH_CHECK(
ivalues.size() == 3,
TORCH_INTERNAL_ASSERT(
ivalues.size() == 4,
"Constructing RRefForkData from ivalue "
"expects a GenericList of 3 elements, but got ",
"expects a GenericList of 4 elements, but got ",
ivalues.size());
int64_t ownerId = ivalues[0].toInt();
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
ownerId < std::numeric_limits<worker_id_t>::max(),
"RRefId createdOn out of range, got ",
ownerId);
@ -41,7 +77,12 @@ RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) {
RRefId rrefId = RRefId::fromIValue(ivalues[1]);
ForkId forkId = ForkId::fromIValue(ivalues[2]);
return RRefForkData(ownerId, rrefId, forkId);
int64_t parent = ivalues[3].toInt();
TORCH_INTERNAL_ASSERT(
parent < std::numeric_limits<worker_id_t>::max(),
"RRefId createdOn out of range, got ",
parent);
return RRefForkData(ownerId, rrefId, forkId, parent);
}
////////////////////////////// RRef /////////////////////////////////////
@ -49,71 +90,100 @@ RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) {
RRef::RRef(worker_id_t ownerId, const RRefId& rrefId)
: ownerId_(ownerId), rrefId_(rrefId) {}
worker_id_t RRef::owner() const {
return ownerId_;
}
const RRefId& RRef::id() const {
return rrefId_;
}
at::IValue RRef::fork() const {
RRefForkData RRef::fork() const {
auto& ctx = RRefContext::getInstance();
return RRefForkData(
ownerId_, rrefId_, RRefContext::getInstance()->genRRefId())
.toIValue();
// NB: does not support sharing RRefs between users
// TODO: notify the owner
ownerId_, rrefId_, ctx->genGloballyUniqueId(), ctx->getWorkerId());
}
////////////////////////// UserRRef /////////////////////////////////////
UserRRef::UserRRef(
template <typename T>
UserRRef<T>::UserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId)
: RRef(ownerId, rrefId), forkId_(forkId) {
AT_ASSERT(
!(forkId_ == rrefId_),
"User RRef's fork ID should not be the same as its rref Id");
if (RRefContext::getInstance()->getWorkerId() == rrefId_.createdOn_) {
// creator user, notify owner.
auto& agent = RRefContext::getInstance()->agent();
agent->send(
agent->getWorkerId(ownerId_),
ScriptRRefCreate(RRefForkData(ownerId_, rrefId_, forkId_).toIValue())
.toMessage());
} else {
AT_ERROR("Does not support sharing RRefs between users yet");
}
// Do nothing,
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
// a RREF_FORK_REQUEST message to the owner.
// (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will
// properly notify the owner.
}
UserRRef::~UserRRef() {
template <typename T>
UserRRef<T>::~UserRRef() {
// TODO: queue this in RRefContext instead of doing it here.
auto& ctx = RRefContext::getInstance();
if (ctx->getWorkerId() != ownerId_) {
ctx->agent()->send(
ctx->agent()->getWorkerId(ownerId_),
ScriptRRefDelete(RRefForkData(ownerId_, rrefId_, forkId_).toIValue())
.toMessage());
auto fm = ctx->agent()->send(
ctx->agent()->getWorkerInfo(ownerId_),
RRefUserDelete(rrefId_, forkId_).toMessage());
fm->addCallback(
[](const Message& message) { RRefContext::handleException(message); });
}
}
const ForkId& UserRRef::forkId() const {
template <typename T>
const ForkId& UserRRef<T>::forkId() const {
return forkId_;
}
bool UserRRef::isOwner() const {
return false;
}
IValue UserRRef::toHere() {
template <>
IValue UserRRef<IValue>::toHere() {
auto& agent = RRefContext::getInstance()->agent();
std::shared_ptr<FutureMessage> fm = agent->send(
agent->getWorkerId(ownerId_),
ScriptRRefFetchCall(id().toIValue()).toMessage());
auto srv = ScriptRRefFetchRet::fromMessage(fm->wait());
return srv.value();
agent->getWorkerInfo(ownerId_),
ScriptRRefFetchCall(rrefId()).toMessage());
const Message& message = fm->wait();
RRefContext::handleException(message);
auto rfr = RRefFetchRet::fromMessage(message);
TORCH_INTERNAL_ASSERT(
rfr->values().size() == 1,
"RRef of IValue should contain a single IValue, but got ",
rfr->values().size());
return rfr->values().front();
}
template <>
py::object UserRRef<py::object>::toHere() {
auto& agent = RRefContext::getInstance()->agent();
std::shared_ptr<FutureMessage> fm = agent->send(
agent->getWorkerInfo(ownerId_),
PythonRRefFetchCall(rrefId()).toMessage());
const Message& message = fm->wait();
RRefContext::handleException(message);
auto rfr = RRefFetchRet::fromMessage(message);
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr->values()));
}
template class UserRRef<IValue>;
template class UserRRef<py::object>;
////////////////////////// OwnerRRef /////////////////////////////////////
template <typename T>
const T& OwnerRRef<T>::getValue() const {
// TODO: use callback to make this non-blocking
std::unique_lock<std::mutex> lock(mutex_);
valueCV_.wait(lock, [this] { return value_.has_value(); });
return value_.value();
}
template <typename T>
void OwnerRRef<T>::setValue(T&& value) {
{
std::lock_guard<std::mutex> lock(mutex_);
value_ = std::move(value);
}
valueCV_.notify_all();
}
template class OwnerRRef<IValue>;
template class OwnerRRef<py::object>;
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -4,6 +4,7 @@
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/utils/pybind.h>
#include <atomic>
@ -13,71 +14,245 @@ namespace rpc {
class RRef;
class RRefContext;
template <typename T>
class UserRRef;
// Represents fork of an RRef to be sent over the wire.
//
// In order to preserve correctness of reference counting, each RRefForkData
// **MUST** be deserialized into a RRef. This means that if RRefForkData is to
// be transferred across the network, we need the guarantee that the message
// will *eventually* get to the peer, and that the peer will create a RRef out
// of it. Therefore, no constructor of RRefForkData is exposed, and
// applications should never directly use RRefForkData. All construction are
// done within ``RRef`` and ``RRefContext``.
struct RRefForkData {
at::IValue toIValue() const;
py::tuple toPyTuple() const;
static RRefForkData fromPyTuple(const py::tuple& obj);
const worker_id_t ownerId_;
const RRefId rrefId_;
const ForkId forkId_;
const worker_id_t parent_;
private:
friend class RRef;
friend class RRefContext;
template <typename T>
friend class UserRRef;
RRefForkData(
worker_id_t ownerId,
const RRefId& rrefId_,
const ForkId& forkId_);
const ForkId& forkId_,
worker_id_t parent);
static RRefForkData fromIValue(const at::IValue&);
const worker_id_t ownerId_;
const RRefId rrefId_;
const ForkId forkId_;
};
static_assert(
C10_IS_TRIVIALLY_COPYABLE(RRefForkData),
"RRefForkData must be trivially copyable");
// Note [RRef Protocol]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// [Background]
//
// RRef stands for Remote REFerence. Each RRef is owned by a single worker
// (i.e., owner) and can be used by multiple users. The owner stores the real
// data referenced by its RRefs. RRef needs to support fast and scalable RPC.
// Hence, in the design, we avoid using a single global master to keep RRef
// states, instead owners will keep track of the global reference counts
// for its RRefs. Every RRef can be uniquely identified by a global RRefId,
// which is assigned at the time it is first created either on a user or on the
// owner.
//
// On the owner worker, there is only one OwnerRRef instance, which contains the
// real data, while on user workers, there can be as many UserRRefs as
// necessary, and UserRRef does not hold the data. All usage on the OwnerRRef
// should retrieve the unique OwnerRRef instance using the globally unique
// RRefId. //A UserRRef will be created when it is used as an argument or return
// value in dist.rpc or dist.remote call, but RRef forking and reference
// counting (RC) are completely transparent to applications. Every UserRRef will
// also have its globally unique ForkId.
//
// [Assumptions]
//
// 1. Transient Network Failures
//
// TODO: current RRef implementation does not tolerate failures
//
// The RRef design aims to handle transient network failures by retrying
// messages. Node crashes or permanent network partition is beyond the scope.
// When those incidents occur, the application may take down all workers, revert
// to the previous checkpoint, and resume training.
//
// 2. Non-idempotent UDFs
//
// We assume UDFs are not idempotent and therefore cannot be retried. However,
// internal RRef control messages will be made idempotent and retryable.
//
// TODO: RRef internal messages are not yet idempotent
//
// 3. Out of Order Message Delivery
//
// We do not assume message delivery order between any pair of nodes, because
// both sender and receiver are using multiple threads. There is no guarantee on
// which message will be processed first.
//
// [RRef Lifetime]
//
// The goal of the protocol is to delete an OwnerRRef at an appropriate time.
// The right time to delete an OwnerRRef is when there are no living UserRRefs
// and Python GC also agrees to delete the OwnerRRef instance on the owner. The
// tricky part is to determine if there are any living UserRRefs.
//
// A user can get a UserRRef in three situations:
//
// (1). Receiving a UserRRef from the owner.
// (2). Receiving a UserRRef from another user.
// (3). Creating a new UserRRef owned by another worker.
//
// (1) is the simplest case where the owner initiates the fork, and hence it can
// easily increment local RC. The only requirement is that any UserRRef must
// notify the owner before destruction. Hence, we need the first guarantee:
//
// G1. The owner will be notified when any UserRRef is deleted.
//
// As messages might come delayed or out-of-order, we need more one guarantee to
// make sure the delete message is not sent out too soon. Let us first introduce
// a new concept. If A sends an RPC to B that involves an RRef, we call the RRef
// on A the parent RRef and the RRef on B the child RRef.
//
// G2. Parent RRef cannot be deleted until the child RRef is confirmed by the
// owner.
//
// Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply
// means that the user will not send out the delete message until all previous
// messages are ACKed. Note that ACKed does not mean the owner finishes
// executing the function, instead, it only means the owner has retrieved its
// local OwnerRRef and about to pass it to the function, which is sufficient to
// keep the OwnerRRef alive even if the delete message from the user arrives at
// the owner before the function finishes execution.
//
// With (2) and (3), it is possible that the owner only partially knows the RRef
// fork graph or not even knowing it at all. For example, the RRef could be
// constructed on a user, and before the owner receives the RPC call, the
// creator user might have already shared the RRef with other users, and those
// users could further share the RRef. One invariant is that the fork graph of
// any RRef is always a tree rooted at the owner, because forking an RRef always
// creates a new RRef instance, and hence every RRef has a single parent. One
// nasty detail is that when an RRef is created on a user, technically the owner
// is not its parent but we still consider it that way and it does not break the
// argument below.
//
// The owner's view on any node (fork) in the tree has three stages:
//
// 1) unknown -> 2) known -> 3) deleted.
//
// The owner's view on the entire tree keeps changing. The owner deletes its
// OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when
// OwnerRRef is deleted, all UserRRefs could be either indeed deleted or
// unknown. The dangerous case is when some forks are unknown and others are
// deleted.
//
// G2 trivially guarantees that no parent UserRRef Y can be deleted before the
// owner knows all of Y's children UserRRefs.
//
// However, it is possible that the child UserRRef Z may be deleted before the
// owner knows its parent Y. More specifically, this can happen when all of Z's
// messages are processed by the owner before all messages from Y, including the
// delete message. Nevertheless, this does not cause any problem. Because, at
// least one of Y's ancestor will be alive, and it will prevent the owner from
// deleting the OwnerRRef. Consider the following example: (NB: this scenario
// will no longer relevant when we block UDF until all RRefs are confirmed by
// the owner)
//
// OwnerRRef -> A -> Y -> Z
//
// OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted
// without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as
// the owner directly forks the RRef to A. A won't die before the owner knows Y.
//
// Things get a little trickier if the RRef is created on a user:
//
// OwnerRRef
// ^
// |
// A -> Y -> Z
//
// If Z calls to_here on the UserRRef, the owner at least knows A when Z is
// deleted, because otherwise to_here wouldn't finish. If Z does not call
// to_here, it is possible that the owner receives all messages from Z before
// any message from A and Y. In this case, as the real data of the OwnerRRef has
// not been created yet, there is nothing to be deleted either. It is the same
// as Z does not exist at all Hence, it's still OK.
//
// See #26759 for more details and discussions.
//
// TODO: make RRef an IValue, and edit createStackForSchema accordingly
// TODO: make RRef system messages idempotent and retry on failures.
//
// ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
// Each ``RRef`` has a globally unique ``RRefId``.
class RRef {
public:
// RRef is made NOT copyable NOT movable to prevent messing up reference
// counting
// counting.
RRef(const RRef& other) = delete;
RRef(RRef&& other) = delete;
RRef& operator=(RRef&& other) = delete;
virtual ~RRef() = default;
worker_id_t owner() const;
const RRefId& id() const;
IValue fork() const;
// returns the worker id of the owner
inline worker_id_t owner() const {
return ownerId_;
}
// Returns the globally unique RRefId of this RRef
inline const RRefId& rrefId() const {
return rrefId_;
}
// Returns true if this is the ``OwnerRRef``
virtual bool isOwner() const = 0;
virtual IValue toHere() = 0;
// returns true if this RRef holds an py::object, false if IValue
virtual bool isPyObj() = 0;
protected:
friend class RRefContext;
RRef(worker_id_t ownerId, const RRefId& rrefId);
RRefForkData fork() const;
const worker_id_t ownerId_;
const RRefId rrefId_;
};
// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
// also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
// never owns the real value, the only way to get the value of the ``RRef`` is
// to call ``to_here()`` and get a copy..
template <typename T>
class UserRRef final : public RRef {
public:
const ForkId& forkId() const;
bool isOwner() const override;
IValue toHere() override;
UserRRef(const UserRRef& other) = delete;
UserRRef(UserRRef&& other) = delete;
UserRRef& operator=(const UserRRef& other) = delete;
UserRRef& operator=(UserRRef&& other) = delete;
inline bool isOwner() const override {
return false;
}
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
// Returns the globally unique ForkId of this RRef
const ForkId& forkId() const;
// Get of copy of the value from the ``OwnerRRef``. If the value is not ready
// yet, this call will block.
T toHere();
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
~UserRRef() override;
private:
@ -93,28 +268,27 @@ class UserRRef final : public RRef {
template <typename T>
class OwnerRRef final : public RRef {
public:
bool isOwner() const override {
OwnerRRef(const OwnerRRef& other) = delete;
OwnerRRef(OwnerRRef&& other) = delete;
OwnerRRef& operator=(const OwnerRRef& other) = delete;
OwnerRRef& operator=(OwnerRRef&& other) = delete;
inline bool isOwner() const override {
return true;
}
T getValue() const {
// TODO: use callback to make this non-blocking
std::unique_lock<std::mutex> lock(mutex_);
valueCV_.wait(lock, [this] { return value_.has_value(); });
return value_.value();
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
void setValue(T&& value) {
{
std::lock_guard<std::mutex> lock(mutex_);
value_ = std::move(value);
}
valueCV_.notify_all();
}
// Get a constant reference of the real value. This method will block if the
// value is not ready. This method does not need GIL as it does not create
// any new py::object.
const T& getValue() const;
IValue toHere() override {
AT_ERROR("OwnerRRef does not support toHere(), use getValue() instead.");
}
// Set the value of this ``OwnerRRef``. This method does not need GIL as it
// does not create any new py::object.
void setValue(T&& value);
private:
friend class RRefContext;
@ -122,9 +296,6 @@ class OwnerRRef final : public RRef {
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId)
: OwnerRRef(ownerId, rrefId, {}) {}
OwnerRRef(OwnerRRef<T>&& other) noexcept
: OwnerRRef(other.owner(), other.id(), std::move(other.value_)) {}
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional<T> value)
: RRef(ownerId, rrefId) {
value_ = std::move(value);

View File

@ -1,10 +1,13 @@
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <sstream>
namespace torch {
namespace distributed {
namespace rpc {
std::unique_ptr<RRefContext> RRefContext::context_;
std::unique_ptr<RRefContext> RRefContext::context_ = nullptr;
void RRefContext::initInstance(std::shared_ptr<RpcAgent> agent) {
TORCH_CHECK(!RRefContext::context_, "Can only initialize RRefContext once.");
@ -20,50 +23,291 @@ std::unique_ptr<RRefContext>& RRefContext::getInstance() {
return RRefContext::context_;
}
void RRefContext::destroyInstance() {
RRefContext::getInstance()->checkRRefLeaks();
RRefContext::context_.reset();
}
void RRefContext::handleException(const Message& message) {
if (message.type() == MessageType::EXCEPTION) {
// TODO: allow users to register an error handler and call it here.
std::string err(message.payload().begin(), message.payload().end());
VLOG(1) << "Got exception: " << err << std::endl << std::flush;
throw std::runtime_error(err);
}
}
RRefContext::RRefContext(std::shared_ptr<RpcAgent> agent)
: agent_(std::move(agent)) {}
worker_id_t RRefContext::getWorkerId() const {
return agent_->getWorkerId().id_;
RRefContext::~RRefContext() {
if (!owners_.empty()) {
AutoGIL ag;
owners_.clear();
}
}
RRefId RRefContext::genRRefId() {
return RRefId(getWorkerId(), nextLocalId_++);
void RRefContext::checkRRefLeaks() {
if (!forks_.empty()) {
std::stringstream ss;
for (auto& entry : forks_) {
const RRefId& rrefId = entry.first;
for (const auto& forkId : entry.second) {
ss << "Leaking RRef " << rrefId << " with fork Id " << forkId
<< std::endl;
}
}
AT_ERROR(ss.str());
}
}
const std::shared_ptr<RpcAgent>& RRefContext::agent() const {
return agent_;
template <typename T>
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(worker_id_t ownerId) {
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
return createUserRRef<T>(
ownerId, genGloballyUniqueId(), genGloballyUniqueId());
}
void RRefContext::addFork(const at::IValue& value) {
auto rfd = RRefForkData::fromIValue(value);
AT_ASSERT(
rfd.ownerId_ == getWorkerId(),
"RRef user should never receive fork notification.");
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
worker_id_t ownerId);
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
py::object>(worker_id_t ownerId);
template <typename T>
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId) {
TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
// RRefContext does not track user RRefs, it will be destructed when there
// is no shared_ptrs pointing to it.
//
// NB: cannot use make_shared here as the constructor of UserRRef is private.
// NB: This UserRRef has not been confirmed by the owner yet. This function's
// call site is responsible for adding this UserRRef to pendingUsers_.
// Currently, there are two call sites.
// (1) The creator user in python_functions.cpp
// (2) The callee user in RRefContext::notifyOwnerAndParentOfFork.
//
// The reason for not adding the pending user here is to put addPendingUser()
// close to where the RPC occurs, and it is more clear to pair it with
// deletePendingUser() in the response callback at the call site.
return std::shared_ptr<UserRRef<T>>(new UserRRef<T>(ownerId, rrefId, forkId));
}
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId);
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
py::object>(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId);
template <typename T>
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(const RRefForkData& rfd) {
auto& ownerId = rfd.ownerId_;
auto& rrefId = rfd.rrefId_;
auto& forkId = rfd.forkId_;
if (ownerId == getWorkerId()) {
return getOrCreateOwnerRRef<T>(rrefId);
} else {
return createUserRRef<T>(ownerId, rrefId, forkId);
}
}
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<IValue>(
const RRefForkData& rfd);
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<py::object>(
const RRefForkData& rfd);
template <typename T>
std::shared_ptr<OwnerRRef<T>> RRefContext::getOrCreateOwnerRRef(
const RRefId& rrefId) {
std::lock_guard<std::mutex> lock(mutex_);
auto& rrefForks = forks_[rfd.rrefId_];
AT_ASSERT(
rrefForks.find(rfd.forkId_) == rrefForks.end(),
const auto iter = owners_.find(rrefId);
if (iter == owners_.end()) {
// Scenario (1) the first time this owner knows about this RRef
//
// NB: cannot use make_shared here as the constructor of OwnerRRef is
// private.
auto rref =
std::shared_ptr<OwnerRRef<T>>(new OwnerRRef<T>(getWorkerId(), rrefId));
owners_[rref->rrefId()] = rref;
return rref;
} else {
// Scenario (2) retrieving an existing RRef
return std::static_pointer_cast<OwnerRRef<T>>(iter->second);
}
}
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::getOrCreateOwnerRRef<
IValue>(const RRefId& rrefId);
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::
getOrCreateOwnerRRef<py::object>(const RRefId& rrefId);
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
auto rfd = rref->fork();
if (rref->isOwner()) {
// Note [Early Fork Registration]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If the parent (caller) is the owner, directly register the fork, instead
// of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message. An
// Alternative is adding the fork when the callee user ACKs. However, before
// that, the owner still have to adds the OwnerRRef into some map to keep it
// alive (e.g., in pendingChildren_). Hence, adding the fork here or in the
// ACK does not making any difference but only add complexity.
// TODO: When adding failure retries and timeout, this fork needs to be
// deleted if the owner does not receive the ACK within the timeout.
addForkOfOwner(rfd.rrefId_, rfd.forkId_);
} else {
// Note [Useful Phantom Fork ID for User to Owner Call]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If the callee of dist.remote or dist.rpc is the owner of this RRef, the
// callee will not create a fork using this rfd.forkId_, because the owner
// will only keep one `OwnerRRef` instance and will not create any
// `UserRRef` instances. However, this rfd.forkId_ is still necessary, as
// the caller user needs to keep this `UserRRef` alive until it gets the
// ACK from the callee owner. Otherwise, the delete message could arrive
// at the owner before this dist.rpc or dist.remote call, which could
// potentially trigger the `OwnerRRef` to be deleted before running the
// user code.
addPendingChild(rfd.forkId_, rref);
}
return rfd;
}
void RRefContext::notifyOwnerAndParentOfFork(
const ForkId& forkId,
worker_id_t parent,
const std::shared_ptr<RRef>& rref) {
if (parent == rref->owner()) {
// If the parent is the owner, this fork has already been added into the
// forks_ map when the owner sends the message to the callee user. Hence,
// it is not necessary to send another RREF_CHILD_ACCEPT or
// RREF_FORK_REQUEST back to the owner. See Note [Early Fork Registration].
return;
}
if (rref->isOwner()) {
// See Note [Useful Phantom Fork ID for User to Owner Call]
// In this case, the owner is the caller, and it does not add the fork id
// into forks_. Because, there will be no real `UserRRef` associated with
// this fork ID.
auto fm = agent_->send(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([](const Message& message) { handleException(message); });
} else {
auto fm = agent_->send(
agent_->getWorkerInfo(rref->owner()),
RRefForkRequest(rref->rrefId(), forkId).toMessage());
addPendingUser(forkId, rref);
fm->addCallback([this, forkId, parent](const Message& message) {
handleException(message);
this->finishForkRequest(forkId, parent);
});
}
}
void RRefContext::addPendingChild(
const ForkId& forkId,
const std::shared_ptr<RRef>& rref) {
// see Note [Early Fork Registration]
// If the parent is the owner, it should directly add the child UserRRef as a
// fork.
TORCH_INTERNAL_ASSERT(
!rref->isOwner(), "OwnerRRef should not have a pending child.");
std::lock_guard<std::mutex> lock(mutex_);
TORCH_INTERNAL_ASSERT(
pendingChildren_.find(forkId) == pendingChildren_.end(),
"Inconsistent states: attempt to add the same child fork twice.");
pendingChildren_[forkId] = rref;
}
void RRefContext::delPendingChild(const ForkId& forkId) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = pendingChildren_.find(forkId);
TORCH_INTERNAL_ASSERT(
iter != pendingChildren_.end(),
"Inconsistent states: attempt to delete a non-exist child fork.");
pendingChildren_.erase(iter);
}
void RRefContext::addPendingUser(
const ForkId& forkId,
const std::shared_ptr<RRef>& rref) {
TORCH_INTERNAL_ASSERT(
!rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");
std::lock_guard<std::mutex> lock(mutex_);
TORCH_INTERNAL_ASSERT(
pendingUsers_.find(forkId) == pendingUsers_.end(),
"Inconsistent states: attempt to add the same UserRRef twice.");
pendingUsers_[forkId] = rref;
}
void RRefContext::delPendingUser(const ForkId& forkId) {
std::lock_guard<std::mutex> lock(mutex_);
auto iter = pendingUsers_.find(forkId);
TORCH_INTERNAL_ASSERT(
iter != pendingUsers_.end(),
"Inconsistent states: attempt to delete a non-exist UserRRef.");
pendingUsers_.erase(iter);
}
void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
delPendingUser(forkId);
auto fm = agent_->send(
agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
fm->addCallback([](const Message& message) { handleException(message); });
}
void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
std::lock_guard<std::mutex> lock(mutex_);
auto& rrefForks = forks_[rrefId];
TORCH_INTERNAL_ASSERT(
rrefForks.find(forkId) == rrefForks.end(),
"Got fork notification twice on the same RRef ",
rfd.rrefId_);
rrefForks.insert(rfd.forkId_);
forkId);
rrefForks.insert(forkId);
}
void RRefContext::delFork(const at::IValue& value) {
auto rfd = RRefForkData::fromIValue(value);
AT_ASSERT(
rfd.ownerId_ == getWorkerId(),
"RRef user should never receive delete notification.");
std::lock_guard<std::mutex> lock(mutex_);
auto& rrefForks = forks_[rfd.rrefId_];
AT_ASSERT(
rrefForks.find(rfd.forkId_) != rrefForks.end(),
"Attempt to delete a non-exist fork ",
rfd.forkId_);
rrefForks.erase(rfd.forkId_);
if (rrefForks.empty()) {
owners_.erase(rfd.rrefId_);
forks_.erase(rfd.rrefId_);
void RRefContext::delForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
std::shared_ptr<RRef> deletedRRef = nullptr;
{
std::lock_guard<std::mutex> lock(mutex_);
auto rrefIter = forks_.find(rrefId);
TORCH_INTERNAL_ASSERT(
rrefIter != forks_.end(),
"Inconsistent states, deleting a fork before the owner knows it.");
auto& rrefForks = rrefIter->second;
auto forkIter = rrefForks.find(forkId);
TORCH_INTERNAL_ASSERT(
forkIter != rrefForks.end(),
"Attempt to delete a non-exist fork ",
forkId);
rrefForks.erase(forkId);
if (rrefForks.empty()) {
auto ownerIter = owners_.find(rrefId);
if (ownerIter != owners_.end()) {
deletedRRef = ownerIter->second;
owners_.erase(ownerIter);
}
forks_.erase(rrefIter);
}
}
if (deletedRRef && deletedRRef->isPyObj()) {
AutoGIL ag;
deletedRRef.reset();
}
}

View File

@ -17,90 +17,103 @@ class RRefContext {
public:
static void initInstance(std::shared_ptr<RpcAgent>);
static std::unique_ptr<RRefContext>& getInstance();
static void destroyInstance();
static void handleException(const Message& message);
RRefContext(const RRefContext&) = delete;
RRefContext(RRefContext&& other) = delete;
void operator=(const RRefContext&) = delete;
RRefContext& operator=(RRefContext&& other) = delete;
worker_id_t getWorkerId() const;
RRefId genRRefId();
const std::shared_ptr<RpcAgent>& agent() const;
~RRefContext();
// create a new RRef
// get the worker id of the current worker
inline worker_id_t getWorkerId() const {
return agent_->getWorkerInfo().id_;
}
// get the worker name of the current worker
inline const std::string& getWorkerName() const {
return agent_->getWorkerInfo().name_;
}
// generate a globally unique ID
inline GloballyUniqueId genGloballyUniqueId() {
return GloballyUniqueId(getWorkerId(), nextLocalId_++);
}
inline const std::shared_ptr<RpcAgent>& agent() const {
return agent_;
}
// create a ``UserRRef`` owned by the worker ``ownerId``
template <typename T>
std::shared_ptr<OwnerRRef<T>> createOwnerRRef(worker_id_t ownerId) {
TORCH_CHECK(ownerId == getWorkerId(), "Cannot create OwnerRRef on user.");
return getOrCreateOwnerRRef<T>(genRRefId());
}
std::shared_ptr<UserRRef<T>> createUserRRef(worker_id_t ownerId);
std::shared_ptr<UserRRef> createUserRRef(worker_id_t ownerId) {
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
return createUserRRef(ownerId, genRRefId(), genRRefId());
}
std::shared_ptr<UserRRef> createUserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId) {
TORCH_CHECK(
ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
// RRefContext does not track user RRefs, it will be destructed when there
// is no shared_ptrs pointing to it. NB: cannot use make_shared here as the
// constructor of UserRRef is private
return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId));
}
// get an existing RRef or create a new one from a serialized
// ``RRefForkData``.
// Convert an RRefForkData into an RRef. This RRef could be user or owner.
// This RRef could have already existed before, or could be created in this
// method.
template <typename T>
std::shared_ptr<RRef> getOrCreateRRef(at::IValue&& value) {
auto rfd = RRefForkData::fromIValue(std::move(value));
return getOrCreateRRef<T>(rfd.ownerId_, rfd.rrefId_, rfd.forkId_);
}
std::shared_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd);
// Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
// one.
template <typename T>
std::shared_ptr<RRef> getOrCreateRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId) {
if (ownerId == getWorkerId()) {
return getOrCreateOwnerRRef<T>(rrefId);
} else {
return createUserRRef(ownerId, rrefId, forkId);
}
}
std::shared_ptr<OwnerRRef<T>> getOrCreateOwnerRRef(const RRefId& rrefId);
template <typename T>
std::shared_ptr<OwnerRRef<T>> getOrCreateOwnerRRef(const RRefId& rrefId) {
std::lock_guard<std::mutex> lock(mutex_);
const auto iter = owners_.find(rrefId);
if (iter == owners_.end()) {
// Scenario (1) the first time this owner knows about this RRef
// Scenario (2) This owner is also the creator.
//
// NB: cannot use make_shared here as the constructor of OwnerRRef is
// private.
auto rref = std::shared_ptr<OwnerRRef<T>>(
new OwnerRRef<T>(getWorkerId(), rrefId));
owners_[rref->id()] = rref;
return rref;
// Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the
// ``OwnerRRef`` in a map to keep it alive.
void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
// Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the
// IValue or py::object. For the later, this method will acquire GIL.
void delForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
} else {
// Scenario (3) retrieving an existing RRef
return std::dynamic_pointer_cast<OwnerRRef<T>>(iter->second);
}
}
// Invoked when pickling an RRef to setup child/fork properly
RRefForkData prepareChildFork(const std::shared_ptr<RRef>& rref);
// Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and
// send RREF_CHILD_ACCEPT to the parent.
// NB: forkId is necessary here as the rref could be an OwnerRRef
void notifyOwnerAndParentOfFork(
const ForkId& forkId,
worker_id_t parent,
const std::shared_ptr<RRef>& rref);
void addFork(const at::IValue& value);
void delFork(const at::IValue& value);
// When a UserRRef is forked to another worker (user or owner), it is added
// into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT
// from the child.
// NB: This is necessary for both user and owner child. As we do not have FIFO
// communication between workers, we need this strategy to make sure that all
// previously submitted rpc/remote calls are acked before sending out the
// RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
// soon.
void addPendingChild(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
void delPendingChild(const ForkId& forkId);
// When a UserRRef is created, it is added into pendingUsers_ to be held alive
// until it receives RREF_USER_ACCEPT from the owner.
void addPendingUser(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
void delPendingUser(const ForkId& forkId);
private:
RRefContext(std::shared_ptr<RpcAgent>);
template <typename T>
std::shared_ptr<UserRRef<T>> createUserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId);
void finishForkRequest(const ForkId& forkId, worker_id_t parent);
// If there is any leak on any RRef, this method will throw an error.
void checkRRefLeaks();
static std::unique_ptr<RRefContext> context_;
static std::atomic<local_id_t> nextLocalId_;
const std::shared_ptr<RpcAgent> agent_;
std::mutex mutex_;
mutable std::mutex mutex_;
// Keep OwnerRRefs alive until there is no living UserRRefs.
std::unordered_map<RRefId, std::shared_ptr<RRef>, RRefId::Hash> owners_;
// Tracks known living UserRRefs of an OwnerRRef
@ -109,6 +122,26 @@ class RRefContext {
std::unordered_set<ForkId, ForkId::Hash>,
RRefId::Hash>
forks_;
// The follow two maps keep UserRRefs alive by holding a shared_ptr to the
// RRef instances. A UserRRef must be added into this map if any of the
// following two conditions is ture:
//
// (1) A UserRRef has not been accepted by owner yet.
//
// It can be used or shared, but cannot be deleted, and hence kept alive
// in this map. A message of type RREF_USER_ACCEPT will remove the
// corresponding RRef from this map.
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash> pendingUsers_;
// (2) A UserRRef has forked a child UserRRef which has not been accepted by
// the owner yet.
//
// In this case, this UserRRef cannot send out RREF_USER_DELETE message,
// as it could potentially trigger the OwnerRRef been deleted before the
// owner learns about the forked child.
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash>
pendingChildren_;
};
} // namespace rpc

View File

@ -0,0 +1,173 @@
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/jit/pickle.h>
#include <limits>
namespace torch {
namespace distributed {
namespace rpc {
namespace {
std::vector<IValue> toIValues(const Message& message, MessageType type) {
TORCH_INTERNAL_ASSERT(
type == message.type(),
"Expecting message of type ",
type,
", but got ",
message.type());
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
return value.toTuple()->elements();
}
Message fromIValues(std::vector<IValue> ivalues, MessageType type) {
std::vector<torch::Tensor> tensor_table;
auto payload = jit::pickle(
c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
return Message(std::move(payload), std::move(tensor_table), type);
}
} // namespace
/////////////////////////// RRefMessageBase //////////////////////////////////
const RRefId& RRefMessageBase::rrefId() {
return rrefId_;
}
Message RRefMessageBase::toMessage() && {
return fromIValues({rrefId_.toIValue()}, type_);
}
at::IValue RRefMessageBase::fromMessage(
const Message& message,
MessageType type) {
auto values = toIValues(message, type);
TORCH_INTERNAL_ASSERT(
values.size() == 1, "ScriptUserDelete expects 1 IValue from message.");
return std::move(values.back());
}
/////////////////////////// ForkMessageBase //////////////////////////////////
const ForkId& ForkMessageBase::forkId() {
return forkId_;
}
Message ForkMessageBase::toMessage() && {
return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_);
}
std::pair<RRefId, ForkId> ForkMessageBase::fromMessage(
const Message& message,
MessageType type) {
auto ivalues = toIValues(message, type);
TORCH_INTERNAL_ASSERT(
ivalues.size() == 2, "ScriptUserDelete expects 2 IValue from message.");
return std::make_pair(
RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1]));
}
/////////////////////////// RRef Protocol //////////////////////////////////
std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
const Message& message) {
return c10::guts::make_unique<ScriptRRefFetchCall>(
RRefId::fromIValue(RRefMessageBase::fromMessage(
message, MessageType::SCRIPT_RREF_FETCH_CALL)));
}
std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
const Message& message) {
return c10::guts::make_unique<PythonRRefFetchCall>(
RRefId::fromIValue(RRefMessageBase::fromMessage(
message, MessageType::PYTHON_RREF_FETCH_CALL)));
}
const std::vector<at::IValue>& RRefFetchRet::values() {
return values_;
}
Message RRefFetchRet::toMessage() && {
std::vector<at::IValue> ivalues = values_;
std::vector<torch::Tensor> tensor_table;
auto payload =
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
return Message(
std::move(payload), std::move(tensor_table), MessageType::RREF_FETCH_RET);
}
std::unique_ptr<RRefFetchRet> RRefFetchRet::fromMessage(
const Message& message) {
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
auto value =
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto values = value.toTuple()->elements();
return c10::guts::make_unique<RRefFetchRet>(std::move(values));
}
std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
const Message& message) {
auto pair =
ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
return c10::guts::make_unique<RRefUserDelete>(
RRefUserDelete(pair.first, pair.second));
}
std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET);
return c10::guts::make_unique<RemoteRet>(pair.first, pair.second);
}
const ForkId& RRefChildAccept::forkId() const {
return forkId_;
}
Message RRefChildAccept::toMessage() && {
return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT);
}
std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage(
const Message& message) {
auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT);
TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message.");
return c10::guts::make_unique<RRefChildAccept>(
ForkId::fromIValue(values.back()));
}
std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage(
const Message& message) {
auto pair =
ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST);
return c10::guts::make_unique<RRefForkRequest>(pair.first, pair.second);
}
Message RRefAck::toMessage() && {
return Message({}, {}, MessageType::RREF_ACK);
}
std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) {
TORCH_INTERNAL_ASSERT(
message.type() == MessageType::RREF_ACK,
"Message type miss match, expect ",
MessageType::RREF_ACK,
", but got ",
message.type());
return c10::guts::make_unique<RRefAck>();
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,137 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
// Temporary solution of RRef operations.
// TODO: Remove all these messages and use rpc + registered functions instead.
class TORCH_API RRefMessageBase : public RpcCommandBase {
public:
RRefMessageBase(const RRefId& rrefId, MessageType type)
: rrefId_(rrefId), type_(type) {}
virtual ~RRefMessageBase() override = default;
const RRefId& rrefId();
virtual Message toMessage() && override;
static at::IValue fromMessage(const Message& message, MessageType type);
protected:
const RRefId rrefId_;
const MessageType type_;
};
class TORCH_API ForkMessageBase : public RRefMessageBase {
public:
ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type)
: RRefMessageBase(rrefId, type), forkId_(forkId) {}
virtual ~ForkMessageBase() override = default;
const ForkId& forkId();
virtual Message toMessage() && override;
static std::pair<RRefId, ForkId> fromMessage(
const Message& message,
MessageType type);
protected:
const ForkId forkId_;
};
// UserRRef uses this message to fetch the remote RRef value from the owner.
class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase {
public:
explicit ScriptRRefFetchCall(const RRefId& rrefId)
: RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL) {}
static std::unique_ptr<ScriptRRefFetchCall> fromMessage(
const Message& message);
};
class TORCH_API PythonRRefFetchCall final : public RRefMessageBase {
public:
explicit PythonRRefFetchCall(const RRefId& rrefId)
: RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL) {}
static std::unique_ptr<PythonRRefFetchCall> fromMessage(
const Message& message);
};
// OwnerRRef uses this message to send the RRef value to a remote UserRRef
class TORCH_API RRefFetchRet final : public RpcCommandBase {
public:
explicit RRefFetchRet(std::vector<at::IValue> values)
: values_(std::move(values)) {}
const std::vector<at::IValue>& values();
Message toMessage() && override;
static std::unique_ptr<RRefFetchRet> fromMessage(const Message& message);
private:
std::vector<at::IValue> values_;
};
// UserRRef (regardless it's the creator or not) uses this message to notiify
// OwnerRRef on delete.
class TORCH_API RRefUserDelete final : public ForkMessageBase {
public:
RRefUserDelete(const RRefId& rrefId, const ForkId& forkId)
: ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {}
static std::unique_ptr<RRefUserDelete> fromMessage(const Message& message);
};
class TORCH_API RemoteRet final : public ForkMessageBase {
public:
RemoteRet(const RRefId& rrefId, const ForkId& forkId)
: ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {}
static std::unique_ptr<RemoteRet> fromMessage(const Message& message);
};
// A child RRef uses this message to notify its parent that the child has been
// confirmed by the owner.
class TORCH_API RRefChildAccept final : public RpcCommandBase {
public:
explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {}
const ForkId& forkId() const;
Message toMessage() && override;
static std::unique_ptr<RRefChildAccept> fromMessage(const Message& message);
private:
const ForkId forkId_;
};
// A child RRef uses this message to send a fork request to the owner.
class TORCH_API RRefForkRequest final : public ForkMessageBase {
public:
RRefForkRequest(const RRefId& rrefId, const ForkId& forkId)
: ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {}
static std::unique_ptr<RRefForkRequest> fromMessage(const Message& message);
};
class TORCH_API RRefAck final : public RpcCommandBase {
public:
RRefAck() {}
Message toMessage() && override;
static std::unique_ptr<RRefAck> fromMessage(const Message& message);
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -9,32 +9,26 @@ namespace rpc {
ScriptRemoteCall::ScriptRemoteCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& args,
at::IValue retRRefId,
at::IValue retForkId)
const RRefId& retRRefId,
const ForkId& retForkId)
: ScriptCall(std::move(op), std::move(args)),
retRRefId_(std::move(retRRefId)),
retForkId_(std::move(retForkId)) {}
const at::IValue& ScriptRemoteCall::retRRefId() {
return retRRefId_;
}
const at::IValue& ScriptRemoteCall::retForkId() {
return retForkId_;
}
retRRefId_(retRRefId),
retForkId_(retForkId) {}
Message ScriptRemoteCall::toMessage() && {
std::vector<IValue> ivalues;
ScriptCall::toIValues(ivalues);
ivalues.push_back(retRRefId_);
ivalues.push_back(retForkId_);
ivalues.emplace_back(retRRefId_.toIValue());
ivalues.emplace_back(retForkId_.toIValue());
std::vector<torch::Tensor> tensor_table;
auto payload =
jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
auto payload = jit::pickle(
c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
return Message(
std::move(payload), std::move(tensor_table), MessageType::REMOTE_CALL);
std::move(payload),
std::move(tensor_table),
MessageType::SCRIPT_REMOTE_CALL);
}
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
@ -47,9 +41,9 @@ std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
auto values = value.toTuple()->elements();
// remove the last element from values and convert it back to an RRef
auto retForkId = std::move(values.back());
auto retForkId = RRefId::fromIValue(values.back());
values.pop_back();
auto retRRefId = std::move(values.back());
auto retRRefId = ForkId::fromIValue(values.back());
values.pop_back();
auto op = ScriptCall::fromIValues(values);

View File

@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <vector>
@ -11,26 +12,32 @@ namespace rpc {
using torch::jit::Operator;
// A ScriptCall instance represents an invocation of a builtin operator for a
// TorchScript function (not implemented yet). If it is a builtin operator, it
// contains a shared ptr to the `Operator` and a list of arguments.
// A ScriptRemoteCall instance represents an invocation of `dist.remote` on a
// builtin operator. Currently, it does not support using RRef as arguments yet.
// Besides the operator and a vector of arguments, ScriptRemoteCall also
// caontains the RRefId and the ForkId of the return value RRef.
class TORCH_API ScriptRemoteCall final : public ScriptCall {
public:
ScriptRemoteCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& args,
at::IValue retRRefId,
at::IValue retForkId);
const RRefId& retRRefId,
const ForkId& retForkId);
const at::IValue& retRRefId();
const at::IValue& retForkId();
inline const RRefId& retRRefId() const {
return retRRefId_;
}
inline const ForkId& retForkId() const {
return retForkId_;
}
Message toMessage() && override;
static std::unique_ptr<ScriptRemoteCall> fromMessage(const Message& message);
private:
const at::IValue retRRefId_;
const at::IValue retForkId_;
const RRefId retRRefId_;
const ForkId retForkId_;
};
} // namespace rpc

View File

@ -4,6 +4,17 @@ namespace torch {
namespace distributed {
namespace rpc {
static_assert(
std::numeric_limits<local_id_t>::max() <=
std::numeric_limits<int64_t>::max(),
"The max value of local_id_t must be within the range of int64_t");
static_assert(
std::numeric_limits<worker_id_t>::max() <=
std::numeric_limits<int64_t>::max(),
"The max value of worker_id_t must be within the range of int64_t");
/////////////////////////// GloballyUniqueId ///////////////////////////
GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId)
: createdOn_(createdOn), localId_(localId) {}
@ -16,8 +27,8 @@ bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const {
}
at::IValue GloballyUniqueId::toIValue() const {
std::vector<at::IValue> ivalues = {(int64_t)createdOn_, (int64_t)localId_};
return c10::ivalue::Tuple::create(std::move(ivalues));
return c10::ivalue::Tuple::create(
{static_cast<int64_t>(createdOn_), static_cast<int64_t>(localId_)});
}
GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
@ -28,18 +39,17 @@ GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
"expects a GenericList of two elements, but got ",
ivalues.size());
worker_id_t createdOn = ivalues[0].toInt();
local_id_t localId = ivalues[1].toInt();
TORCH_CHECK(
createdOn < std::numeric_limits<worker_id_t>::max(),
ivalues[0].toInt() <= std::numeric_limits<worker_id_t>::max(),
"GloballyUniqueId createdOn out of range, got ",
createdOn);
ivalues[0].toInt());
worker_id_t createdOn = ivalues[0].toInt();
TORCH_CHECK(
localId < std::numeric_limits<local_id_t>::max(),
ivalues[1].toInt() <= std::numeric_limits<local_id_t>::max(),
"GloballyUniqueId localId out of range, got ",
localId);
ivalues[1].toInt());
local_id_t localId = ivalues[1].toInt();
return GloballyUniqueId(createdOn, localId);
}
@ -49,6 +59,29 @@ std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) {
<< globalId.localId_ << ")";
}
/////////////////////////// SerializedPyObj ///////////////////////////
std::vector<at::IValue> SerializedPyObj::toIValues() const {
std::vector<at::IValue> ivalues;
ivalues.reserve(tensors_.size() + 1);
for (auto& tensor : tensors_) {
ivalues.emplace_back(tensor);
}
ivalues.emplace_back(payload_);
return ivalues;
}
SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) {
std::string payload = values.back().toStringRef();
values.pop_back();
std::vector<at::Tensor> tensors;
tensors.reserve(values.size());
for (auto& value : values) {
tensors.emplace_back(value.toTensor());
}
return SerializedPyObj(std::move(payload), std::move(tensors));
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -8,11 +8,12 @@ namespace distributed {
namespace rpc {
using worker_id_t = int16_t;
using local_id_t = uint64_t;
using local_id_t = int64_t;
struct GloballyUniqueId final {
struct TORCH_API GloballyUniqueId final {
GloballyUniqueId(worker_id_t createdOn, local_id_t localId);
GloballyUniqueId(const GloballyUniqueId& other) = default;
GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete;
bool operator==(const GloballyUniqueId& other) const;
bool operator!=(const GloballyUniqueId& other) const;
@ -32,11 +33,24 @@ struct GloballyUniqueId final {
const local_id_t localId_;
};
std::ostream& operator<<(std::ostream& os, const GloballyUniqueId& globalId);
TORCH_API std::ostream& operator<<(
std::ostream& os,
const GloballyUniqueId& globalId);
using RRefId = GloballyUniqueId;
using ForkId = GloballyUniqueId;
struct TORCH_API SerializedPyObj final {
SerializedPyObj(std::string&& payload, std::vector<at::Tensor>&& tensors)
: payload_(std::move(payload)), tensors_(std::move(tensors)) {}
std::vector<at::IValue> toIValues() const;
static SerializedPyObj fromIValues(std::vector<at::IValue> value);
const std::string payload_;
const std::vector<at::Tensor> tensors_;
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -1,11 +1,12 @@
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/script_rref_proto.h>
namespace torch {
namespace distributed {
@ -19,17 +20,26 @@ std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
case MessageType::PYTHON_CALL: {
return PythonUDFCall::fromMessage(request);
}
case MessageType::REMOTE_CALL: {
case MessageType::SCRIPT_REMOTE_CALL: {
return ScriptRemoteCall::fromMessage(request);
}
case MessageType::RREF_FETCH_CALL: {
case MessageType::PYTHON_REMOTE_CALL: {
return PythonRemoteCall::fromMessage(request);
}
case MessageType::SCRIPT_RREF_FETCH_CALL: {
return ScriptRRefFetchCall::fromMessage(request);
}
case MessageType::RREF_USER_CREATE: {
return ScriptRRefCreate::fromMessage(request);
case MessageType::PYTHON_RREF_FETCH_CALL: {
return PythonRRefFetchCall::fromMessage(request);
}
case MessageType::RREF_USER_DELETE: {
return ScriptRRefDelete::fromMessage(request);
return RRefUserDelete::fromMessage(request);
}
case MessageType::RREF_CHILD_ACCEPT: {
return RRefChildAccept::fromMessage(request);
}
case MessageType::RREF_FORK_REQUEST: {
return RRefForkRequest::fromMessage(request);
}
case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: {
return RpcWithAutograd::fromMessage(request);
@ -49,6 +59,15 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
case MessageType::PYTHON_RET: {
return PythonUDFResp::fromMessage(response);
}
case MessageType::REMOTE_RET: {
return RemoteRet::fromMessage(response);
}
case MessageType::RREF_FETCH_RET: {
return RRefFetchRet::fromMessage(response);
}
case MessageType::RREF_ACK: {
return RRefAck::fromMessage(response);
}
case MessageType::EXCEPTION: {
std::string err(response.payload().begin(), response.payload().end());
throw std::runtime_error(err);

View File

@ -54,4 +54,4 @@ if is_available():
"""
_init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads)
from .rpc_api import _agent
autograd._init(_agent.get_worker_id().id)
autograd._init(_agent.get_worker_info().id)

View File

@ -100,6 +100,8 @@ class _InternalRPCPickler:
# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()
def serialize(obj):
return _internal_rpc_pickler.serialize(obj)
def run_python_udf_internal(pickled_python_udf, tensors):
r"""
@ -114,7 +116,8 @@ def run_python_udf_internal(pickled_python_udf, tensors):
# except str = exception info + traceback string
except_str = "{}\n{}".format(repr(e), traceback.format_exc())
result = RemoteException(except_str)
return _internal_rpc_pickler.serialize(result)
# return _internal_rpc_pickler.serialize(result)
return result
def load_python_udf_result_internal(pickled_python_result, tensors):

View File

@ -1,11 +1,12 @@
#!/usr/bin/env python3
from . import invoke_rpc_builtin, invoke_rpc_python_udf, invoke_remote_builtin
from . import init_rref_context
from . import invoke_rpc_builtin, invoke_rpc_python_udf
from . import invoke_remote_builtin, invoke_remote_python_udf
from . import _init_rref_context, _destroy_rref_context
from . import ProcessGroupAgent
from . import WorkerId
from .rpc_backend_registry import is_rpc_backend_registered, init_rpc_backend
from . import WorkerInfo
from .internal_rpc_utils import _internal_rpc_pickler, PythonUDF
from .rpc_backend_registry import is_rpc_backend_registered, init_rpc_backend
import functools
import sys
@ -38,6 +39,7 @@ def join_rpc():
if _agent:
_agent.join()
_agent = None
_destroy_rref_context()
@_require_initialized
@ -78,40 +80,41 @@ def _init_rpc(backend=RpcBackend.PROCESS_GROUP,
self_rank, group.rank()))
# TODO: add try-except and destroy _agent in all processes if any fails.
_agent = ProcessGroupAgent(self_name, group, num_send_recv_threads)
init_rref_context(_agent)
_init_rref_context(_agent)
elif is_rpc_backend_registered(backend):
_agent = init_rpc_backend(
backend,
self_rank=self_rank,
self_name=self_name,
init_method=init_method,
init_method=init_method
)
init_rref_context(_agent)
_init_rref_context(_agent)
else:
raise RuntimeError("Unrecognized RPC backend ", backend)
@_require_initialized
def get_worker_id(worker_name=None):
def get_worker_info(worker_name=None):
r"""
Get worker id of a given worker name. Use this worker id to avoid passing
an expensive string to ``rpc`` on every invocation.
Get WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing
an expensive string to ``rpc`` on every invocation. The WorkerInfo contains
the name of the worker and the id of the worker.
Arguments:
worker_name (str): the string name of a worker. If ``None``, return the
the id of the current worker. (default ``None``)
"""
if worker_name:
return _agent.get_worker_id(worker_name)
return _agent.get_worker_info(worker_name)
else:
return _agent.get_worker_id()
return _agent.get_worker_info()
def _to_worker_id(name_or_id):
if isinstance(name_or_id, WorkerId):
def _to_worker_info(name_or_id):
if isinstance(name_or_id, WorkerInfo):
return name_or_id
elif isinstance(name_or_id, str):
return get_worker_id(name_or_id)
return get_worker_info(name_or_id)
else:
raise ValueError("Unsupported RPC worker ID type {}".format(name_or_id))
@ -142,7 +145,7 @@ def remote(to, func, args=None, kwargs=None):
>>> import torch.distributed as dist
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> dist.init_rpc("worker0")
>>> worker1 = dist.get_worker_id("worker1")
>>> worker1 = dist.get_worker_info("worker1")
>>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3))
>>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
@ -159,8 +162,15 @@ def remote(to, func, args=None, kwargs=None):
args = args if args else ()
kwargs = kwargs if kwargs else {}
return invoke_remote_builtin(
_agent, _to_worker_id(to), qualified_name, *args, **kwargs)
info = _to_worker_info(to)
if qualified_name is not None:
return invoke_remote_builtin(
_agent, info, qualified_name, *args, **kwargs)
else:
(pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
PythonUDF(func, args, kwargs))
return invoke_remote_python_udf(
_agent, info, pickled_python_udf, tensors)
def _invoke_rpc(to, func, args=None, kwargs=None):
@ -172,15 +182,16 @@ def _invoke_rpc(to, func, args=None, kwargs=None):
args = args if args else ()
kwargs = kwargs if kwargs else {}
info = _to_worker_info(to)
if qualified_name is not None:
fut = invoke_rpc_builtin(
_agent, _to_worker_id(to), qualified_name, *args, **kwargs
_agent, info, qualified_name, *args, **kwargs
)
else:
(pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
PythonUDF(func, args, kwargs))
fut = invoke_rpc_python_udf(
_agent, _to_worker_id(to), pickled_python_udf, tensors)
_agent, info, pickled_python_udf, tensors)
return fut
@ -314,7 +325,7 @@ def rpc(to, func, args=None, kwargs=None, async_call=False):
>>> import torch.distributed as dist
>>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
>>> dist.init_model_parallel("worker0")
>>> worker1 = dist.get_worker_id("worker1")
>>> worker1 = dist.get_worker_info("worker1")
>>> fut1 = dist.rpc(worker1, torch.add, args=(torch.ones(2), 3), async_call=True)
>>> fut2 = dist.rpc(worker1, min, args=(1, 2), async_call=True)
>>> result = fut1.wait() + fut2.wait()
@ -330,6 +341,7 @@ def rpc(to, func, args=None, kwargs=None, async_call=False):
"""dist.rpc is deprecated. Use dist.rpc_async for asynchronous
calls or dist.rpc_sync for synchronous calls instead."""
)
if async_call:
return rpc_async(to, func, args, kwargs)
else: