[RPC] Create local RRef<ModuleInterface> remotely in Python, use it remotely in TorchScript (#34183)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34183

https://github.com/pytorch/pytorch/pull/33263 enhanced the RRef Python constructor to infer most types, by `jit::tryToInferType(..)`.

But this helper function can't infer `ScriptModule` type due to `ScriptModule`'s special per-Module type singleton logic, so it's still not possible for a Python-created RRef to know the JIT type of its contained `ScriptModule`.

Instead of inferring the specific type of a Module, which could lead to too many candidate types (due to Module's multiple inheritance possibility), it's more straightforward to set its type as a user-specified `ModuleInterface` type.

We added an optional argument `type_hint` for users to mark an `RRef` with the `ModuleInterface` type it holds.

ghstack-source-id: 99649379

(Note: this ignores all push blocking failures!)

Test Plan:
Aspects that need to be confirmed in the test cases

https://fb.quip.com/aGxRAh2lCg05

```
buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_class_rref

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_module_rref

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_class_rref_in_py_and_use_in_script

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_module_rref_in_py_and_use_in_script

buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_torchscript_function_exception
```

Differential Revision: D7065050

fbshipit-source-id: e10210c0996622969e499e4a35b0659b36787c1c
This commit is contained in:
Shihao Xu
2020-03-06 08:24:54 -08:00
committed by Facebook Github Bot
parent a7da4490cc
commit 17ceb6941f
4 changed files with 189 additions and 77 deletions

View File

@ -149,7 +149,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
>>> # count is automatically updated.
>>> rpc.rpc_sync("worker1", f, args(rref,))
)")
.def(py::init<const py::object&>())
.def(
py::init<const py::object&, const py::object&>(),
py::arg("value"),
py::arg("type_hint") = py::none())
.def(
// not releasing GIL here to avoid context switch on getters
"is_owner",

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
@ -46,6 +47,50 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
// Determine the JIT TypePtr for the value an RRef is about to hold.
//
// If `value` is a ScriptModule, a non-None `type_hint` naming a
// ModuleInterface is required; the hint is resolved in the Python
// compilation unit and validated as a supertype of the module's type.
// Otherwise `type_hint` must be None and the type is inferred via
// `jit::tryToInferType(..)`, falling back to PyObjectType for plain
// Python objects.
TypePtr tryInferTypeWithTypeHint(
    const py::object& value,
    const py::object& type_hint) {
  // If the py::object to be contained by the RRef is a ScriptModule, we
  // require users to specify its ModuleInterface type.
  if (auto module = jit::script::as_module(value)) {
    TORCH_CHECK(
        !type_hint.is_none(),
        "The RRef being created contains a ScriptModule, "
        "must provide its ModuleInterface type hint. ");
    // Resolve the hint's qualified name via torch.jit._qualified_name so the
    // interface can be looked up in the shared Python compilation unit.
    c10::QualifiedName type_qualified_name = c10::QualifiedName(
        py::cast<std::string>(py::module::import("torch.jit")
                                  .attr("_qualified_name")(type_hint)));
    TypePtr type_hint_ptr =
        jit::get_python_cu()->get_interface(type_qualified_name);
    TORCH_CHECK(
        type_hint_ptr != nullptr &&
            module.value().type()->isSubtypeOf(type_hint_ptr),
        module.value().type()->python_str(),
        " is not a subtype of the type hint: ",
        type_qualified_name.qualifiedName(),
        ", did you pass a valid interface type?");
    return type_hint_ptr;
  } else {
    // A hint is only meaningful (and only accepted) for ScriptModules.
    TORCH_CHECK(
        type_hint.is_none(),
        "type_hint should only be specified when the RRef being created contains a ScriptModule.");
  }
  // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
  // excluding ScriptModule.
  jit::InferredType type_inferred = jit::tryToInferType(value);
  if (type_inferred.success()) {
    // If we could infer the type from the pyobject, we create
    // the RRef with the IValue of that type.
    return type_inferred.type();
  }
  // Otherwise it's a pure pyobject, create the RRef
  // that holds an IValue of a pyobject.
  return PyObjectType::get();
} // namespace
} // namespace
/////////////////////////// PyRRef //////////////////////////////////
@ -54,19 +99,9 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}
PyRRef::PyRRef(const py::object& value)
: PyRRef([&value]() {
jit::InferredType type_inferred = jit::tryToInferType(value);
TypePtr elem_type = nullptr;
if (type_inferred.success()) {
// If we could infer the type from the pyobject, we create
// the RRef with the IValue of that type.
elem_type = type_inferred.type();
} else {
// Otherwise it's a pure pyobject, create the RRef
// that holds an IValue of an pyobject
elem_type = PyObjectType::get();
}
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
: PyRRef([&value, &type_hint]() {
TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
py::object copy(value); // increases refcount
IValue ivalue = jit::toIValue(std::move(copy), elem_type);

View File

@ -12,7 +12,7 @@ namespace rpc {
// pickle and unpickle.
class PyRRef {
public:
explicit PyRRef(const py::object& value);
explicit PyRRef(const py::object& value, const py::object& type_hint);
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
bool isOwner() const;

View File

@ -10,25 +10,10 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
)
def python_function():
    """Plain (non-TorchScript) callable used as an RPC target; returns zero."""
    zero = 0
    return zero
def rpc_return_rref(dst):
    """Create a remote reference on worker ``dst`` computing ones(2, 2) + 1."""
    remote_args = (torch.ones(2, 2), 1)
    return rpc.remote(dst, torch.add, args=remote_args)
# Define Script functions on both client and server sides.
@torch.jit.script
def no_arg():
    # Scripted function taking no inputs; always yields a constant zero.
    zero = 0
    return zero
@torch.jit.script
def one_arg(value):
    # Scripted element-wise increment of the input tensor.
    incremented = value + 1
    return incremented
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
def __init__(self, dst_worker):
super().__init__()
@ -48,8 +33,13 @@ class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
@torch.jit.script
class MyScriptClass:
def __init__(self):
self.a = 10
def __init__(self, a):
# type: (int)
self.a = a
def get_value(self):
# type: () -> int
return self.a
@torch.jit.interface
@ -70,47 +60,94 @@ class MyScriptModule(torch.jit.ScriptModule):
return self.a
@torch.jit.script
def rref_to_here(rref_var):
    # type: (RRef[Tensor]) -> Tensor
    # Fetch the tensor held by the RRef onto the local node.
    fetched = rref_var.to_here()
    return fetched
def owner_create_rref_my_script_class(a):
    """Owner-side helper: wrap a fresh MyScriptClass instance in a local RRef."""
    instance = MyScriptClass(a)
    return rpc.RRef(instance)
def owner_create_rref_my_script_module(a):
    """Owner-side helper: wrap a MyScriptModule in a local RRef, typed by its
    ModuleInterface so the RRef can be used remotely from TorchScript."""
    module = MyScriptModule(a)
    return rpc.RRef(module, MyModuleInterface)
@torch.jit.script
def return_rref(rref_var):
    # type: (RRef[Tensor]) -> RRef[Tensor]
    # Identity over an RRef; exercises passing RRefs through scripted code.
    unchanged = rref_var
    return unchanged
@torch.jit.ignore
def my_script_module_init(rank):
    # type: (int) -> MyModuleInterface
    # Eagerly construct a MyScriptModule; the script compiler skips this body.
    module = MyScriptModule(rank)
    return module
def script_run_get_value_rref_my_script_class(rref):
    # type: (RRef[MyScriptClass]) -> int
    # Fetch the script class held by the RRef and read its stored value.
    local_instance = rref.to_here()
    return local_instance.get_value()
@torch.jit.script
def construct_my_script_module(rank):
    # type: (int) -> MyModuleInterface
    # Scripted wrapper that delegates module construction to eager Python.
    built = my_script_module_init(rank)
    return built
def script_run_forward_rref_my_script_module(rref):
    # type: (RRef[MyModuleInterface]) -> Tensor
    # Fetch the module held by the RRef and run its forward pass locally.
    local_module = rref.to_here()
    return local_module.forward()
class LocalRRefTest(RpcAgentTestFixture):
    """Tests creating owner (local) RRefs in Python — optionally with a
    ModuleInterface type hint — and using them remotely from TorchScript."""

    @dist_init
    def test_create_local_script_class_rref_in_py(self):
        """A local RRef around a script class instance round-trips via to_here()."""
        # Only rank 0 runs the body; other ranks just participate in RPC init.
        if self.rank != 0:
            return

        # Create a local RRef<MyScriptClass>.
        rref_script_class = rpc.RRef(MyScriptClass(self.rank, ))
        ret = rref_script_class.to_here().get_value()
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_create_local_script_module_rref_in_py(self):
        """A local RRef around a ScriptModule requires a ModuleInterface hint."""
        if self.rank != 0:
            return

        # Create a local RRef<MyModuleInterface>.
        rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
        ret = rref_script_module.to_here().forward()
        self.assertEqual(ret, torch.ones(self.rank))

        # Create a local RRef<MyModuleInterface> without type hint: must raise.
        with self.assertRaisesRegex(
            RuntimeError, (
                "The RRef being created contains a ScriptModule, "
                "must provide its ModuleInterface type hint."
            )
        ):
            rref_script_module = rpc.RRef(MyScriptModule(self.rank))

    @dist_init
    def test_return_local_script_class_rref_in_py_and_use_in_script(self):
        """RRef<MyScriptClass> created remotely in Python is usable in Script."""
        if self.rank != 0:
            return

        dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)

        # Create a local RRef<MyScriptClass> remotely in Python.
        rref = rpc.rpc_sync(dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,))

        # Use RRef<MyScriptClass> remotely in Script.
        ret = rpc.rpc_sync(
            rref.owner(), script_run_get_value_rref_my_script_class, args=(rref,)
        )
        self.assertEqual(ret, self.rank)

    @dist_init
    def test_return_local_script_module_rref_in_py_and_use_in_script(self):
        """RRef<MyModuleInterface> created remotely in Python is usable in Script."""
        if self.rank != 0:
            return

        dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)

        # Create a local RRef<MyModuleInterface> remotely in Python.
        rref = rpc.rpc_sync(dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,))

        # Use RRef<MyModuleInterface> remotely in Script.
        ret = rpc.rpc_sync(
            rref.owner(), script_run_forward_rref_my_script_module, args=(rref,)
        )
        self.assertEqual(ret, torch.ones(self.rank))
def python_function():
    """Non-script callable invoked over RPC; always returns a constant 0."""
    return int("0")
@torch.jit.script
def run_ref_script_module(ref_script_module, t):
    # type: (RRef[MyModuleInterface], Tensor) -> Tensor
    # Materialize the remote module locally, run forward, and offset by t.
    local_module = ref_script_module.to_here()
    forward_out = local_module.forward()
    return forward_out + t
@torch.jit.ignore
def rref_python_annotation(rref_var):
    # type: (RRef[Tensor]) -> RRef[Tensor]
    # Pass-through kept in eager Python (ignored by the script compiler).
    passthrough = rref_var
    return passthrough
@torch.jit.script
def rref_script_annotation(rref_var):
    # type: (RRef[Tensor]) -> Tensor
    # Round-trip the RRef through an ignored Python function, then fetch it.
    forwarded = rref_python_annotation(rref_var)
    return forwarded.to_here()
def no_arg():
    """Zero-argument helper used as an RPC target; returns a constant."""
    result = 0
    return result
@torch.jit.script
@ -450,10 +487,58 @@ class JitRpcAsyncOpTest:
self.assertEqual(ret, 0)
@torch.jit.script
def one_arg(value):
    # Scripted helper: add one to every element of the input.
    bumped = value + 1
    return bumped
@torch.jit.script
def rref_to_here(rref_var):
    # type: (RRef[Tensor]) -> Tensor
    # Pull the RRef's tensor value to the calling worker.
    local_value = rref_var.to_here()
    return local_value
@torch.jit.script
def return_rref(rref_var):
    # type: (RRef[Tensor]) -> RRef[Tensor]
    # Scripted identity over an RRef argument.
    echoed = rref_var
    return echoed
@torch.jit.ignore
def my_script_module_init(rank):
    # type: (int) -> MyModuleInterface
    # Build the module in eager mode; the script compiler ignores this body.
    constructed = MyScriptModule(rank)
    return constructed
@torch.jit.script
def construct_my_script_module(rank):
    # type: (int) -> MyModuleInterface
    # Scripted entry point delegating construction to the ignored helper.
    instance = my_script_module_init(rank)
    return instance
@torch.jit.script
def run_ref_script_module(ref_script_module, t):
    # type: (RRef[MyModuleInterface], Tensor) -> Tensor
    # Fetch the module from the RRef, evaluate it, and add the offset tensor.
    fetched = ref_script_module.to_here()
    result = fetched.forward() + t
    return result
@torch.jit.ignore
def rref_python_annotation(rref_var):
    # type: (RRef[Tensor]) -> RRef[Tensor]
    # Eager-mode identity; exists to test @torch.jit.ignore with RRef types.
    same_ref = rref_var
    return same_ref
@torch.jit.script
def rref_script_annotation(rref_var):
    # type: (RRef[Tensor]) -> Tensor
    # Route the RRef through the ignored Python helper, then fetch its value.
    routed = rref_python_annotation(rref_var)
    return routed.to_here()
@unittest.skipIf(
not torch._six.PY3, "Pytorch distributed rpc package does not support python2"
)
class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
class JitRpcTest(LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixture):
@dist_init
def test_torchscript_function(self):
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
@ -484,7 +569,7 @@ class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
# rpc_sync still accepts script class and run it in
# the same code path as python call.
ret = rpc.rpc_sync(
dst_worker_name, MyScriptClass, args=()
dst_worker_name, MyScriptClass, args=(self.rank,)
)
# rpc_sync does not accept script module and script module method.
@ -588,14 +673,3 @@ class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
res = rref_script_annotation(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
@dist_init
def test_local_rref_creation_with_ivalue(self):
    """Create owner RRefs locally from a script class instance and from a
    ScriptModule's underlying C++ module, and read them back via to_here()."""
    # NOTE(review): MyScriptClass() is called with no arguments here and the
    # assertion expects .a == 10 — elsewhere in this file MyScriptClass.__init__
    # takes an `a` parameter; confirm which constructor this test targets.
    # create a local RRef that holds a IValue
    rref_local_script_class = rpc.RRef(MyScriptClass())
    self.assertEqual(rref_local_script_class.to_here().a, 10)

    # create a local RRef that holds a ScriptModule
    # (._c is the underlying C++ module of the ScriptModule)
    rref_local_script_mod = rpc.RRef(MyScriptModule(3)._c)
    self.assertEqual(rref_local_script_mod.to_here().forward(), torch.ones(3))