mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[RPC] Create local RRef<ModuleInterface> remotely in Python, use it remotely in TorchScript (#34183)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34183 https://github.com/pytorch/pytorch/pull/33263 enhanced the RRef Python constructor to infer most types, by `jit::tryToInferType(..)`. But this helper function can't infer `ScriptModule` type due to `ScriptModule`'s special per-Module type singleton logic, so it's still not possible for an Python-created RRef to know the JIT type of it's contained `ScriptModule`. Instead of inferring the specific type of a Module, which could leads to too many candidate types (due to Module's multiple inheritance possibility), it's more straightforward to set it's type as a user-specified `ModuleInterface` type. We added an optional argument `type_hint` for users to mark an `RRef` for what `ModuleInterface` type it's holds. ghstack-source-id: 99649379 (Note: this ignores all push blocking failures!) Test Plan: Aspects that need to be confirmed in the test cases https://fb.quip.com/aGxRAh2lCg05 ``` buck test mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \ && buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_class_rref buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \ && buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_create_local_script_module_rref buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \ && buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_class_rref_in_py_and_use_in_script buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \ && buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_return_local_script_module_rref_in_py_and_use_in_script buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \ && buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par -r test_torchscript_function_exception ``` Differential Revision: D7065050 fbshipit-source-id: e10210c0996622969e499e4a35b0659b36787c1c
This commit is contained in:
committed by
Facebook Github Bot
parent
a7da4490cc
commit
17ceb6941f
@ -149,7 +149,10 @@ PyObject* rpc_init(PyObject* /* unused */) {
|
||||
>>> # count is automatically updated.
|
||||
>>> rpc.rpc_sync("worker1", f, args(rref,))
|
||||
)")
|
||||
.def(py::init<const py::object&>())
|
||||
.def(
|
||||
py::init<const py::object&, const py::object&>(),
|
||||
py::arg("value"),
|
||||
py::arg("type_hint") = py::none())
|
||||
.def(
|
||||
// not releasing GIL here to avoid context switch on getters
|
||||
"is_owner",
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <torch/csrc/distributed/rpc/python_functions.h>
|
||||
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
||||
#include <torch/csrc/distributed/rpc/rref_context.h>
|
||||
#include <torch/csrc/jit/python/module_python.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
namespace torch {
|
||||
@ -46,6 +47,50 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
|
||||
|
||||
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
|
||||
}
|
||||
|
||||
TypePtr tryInferTypeWithTypeHint(
|
||||
const py::object& value,
|
||||
const py::object& type_hint) {
|
||||
// If the py::object to be contained by the RRef is a ScripModule, we enforce
|
||||
// users to specify its ModuleInterface type.
|
||||
if (auto module = jit::script::as_module(value)) {
|
||||
TORCH_CHECK(
|
||||
!type_hint.is_none(),
|
||||
"The RRef being created contains a ScriptModule, "
|
||||
"must provide its ModuleInterface type hint. ");
|
||||
c10::QualifiedName type_qualified_name = c10::QualifiedName(
|
||||
py::cast<std::string>(py::module::import("torch.jit")
|
||||
.attr("_qualified_name")(type_hint)));
|
||||
TypePtr type_hint_ptr =
|
||||
jit::get_python_cu()->get_interface(type_qualified_name);
|
||||
TORCH_CHECK(
|
||||
type_hint_ptr != nullptr &&
|
||||
module.value().type()->isSubtypeOf(type_hint_ptr),
|
||||
module.value().type()->python_str(),
|
||||
" is not a subtype of the type hint: ",
|
||||
type_qualified_name.qualifiedName(),
|
||||
", did you pass a valid interface type?");
|
||||
return type_hint_ptr;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
type_hint.is_none(),
|
||||
"type_hint should only be specified when the RRef being created contains a ScriptModule.");
|
||||
}
|
||||
|
||||
// NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
|
||||
// excluding ScripModule.
|
||||
jit::InferredType type_inferred = jit::tryToInferType(value);
|
||||
if (type_inferred.success()) {
|
||||
// If we could infer the type from the pyobject, we create
|
||||
// the RRef with the IValue of that type.
|
||||
return type_inferred.type();
|
||||
}
|
||||
|
||||
// Otherwise it's a pure pyobject, create the RRef
|
||||
// that holds an IValue of an pyobject.
|
||||
return PyObjectType::get();
|
||||
} // namespace
|
||||
|
||||
} // namespace
|
||||
|
||||
/////////////////////////// PyRRef //////////////////////////////////
|
||||
@ -54,19 +99,9 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
|
||||
}
|
||||
|
||||
PyRRef::PyRRef(const py::object& value)
|
||||
: PyRRef([&value]() {
|
||||
jit::InferredType type_inferred = jit::tryToInferType(value);
|
||||
TypePtr elem_type = nullptr;
|
||||
if (type_inferred.success()) {
|
||||
// If we could infer the type from the pyobject, we create
|
||||
// the RRef with the IValue of that type.
|
||||
elem_type = type_inferred.type();
|
||||
} else {
|
||||
// Otherwise it's a pure pyobject, create the RRef
|
||||
// that holds an IValue of an pyobject
|
||||
elem_type = PyObjectType::get();
|
||||
}
|
||||
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
|
||||
: PyRRef([&value, &type_hint]() {
|
||||
TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
|
||||
auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
|
||||
py::object copy(value); // increases refcount
|
||||
IValue ivalue = jit::toIValue(std::move(copy), elem_type);
|
||||
|
@ -12,7 +12,7 @@ namespace rpc {
|
||||
// pickle and unpickle.
|
||||
class PyRRef {
|
||||
public:
|
||||
explicit PyRRef(const py::object& value);
|
||||
explicit PyRRef(const py::object& value, const py::object& type_hint);
|
||||
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
|
||||
|
||||
bool isOwner() const;
|
||||
|
@ -10,25 +10,10 @@ from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
|
||||
)
|
||||
|
||||
|
||||
def python_function():
|
||||
return 0
|
||||
|
||||
|
||||
def rpc_return_rref(dst):
|
||||
return rpc.remote(dst, torch.add, args=(torch.ones(2, 2), 1))
|
||||
|
||||
|
||||
# Define Script functions on both client and server sides.
|
||||
@torch.jit.script
|
||||
def no_arg():
|
||||
return 0
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def one_arg(value):
|
||||
return value + 1
|
||||
|
||||
|
||||
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
|
||||
def __init__(self, dst_worker):
|
||||
super().__init__()
|
||||
@ -48,8 +33,13 @@ class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
|
||||
|
||||
@torch.jit.script
|
||||
class MyScriptClass:
|
||||
def __init__(self):
|
||||
self.a = 10
|
||||
def __init__(self, a):
|
||||
# type: (int)
|
||||
self.a = a
|
||||
|
||||
def get_value(self):
|
||||
# type: () -> int
|
||||
return self.a
|
||||
|
||||
|
||||
@torch.jit.interface
|
||||
@ -70,47 +60,94 @@ class MyScriptModule(torch.jit.ScriptModule):
|
||||
return self.a
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rref_to_here(rref_var):
|
||||
# type: (RRef[Tensor]) -> Tensor
|
||||
return rref_var.to_here()
|
||||
def owner_create_rref_my_script_class(a):
|
||||
return rpc.RRef(MyScriptClass(a))
|
||||
|
||||
|
||||
def owner_create_rref_my_script_module(a):
|
||||
return rpc.RRef(MyScriptModule(a), MyModuleInterface)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def return_rref(rref_var):
|
||||
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||
return rref_var
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def my_script_module_init(rank):
|
||||
# type: (int) -> MyModuleInterface
|
||||
return MyScriptModule(rank)
|
||||
def script_run_get_value_rref_my_script_class(rref):
|
||||
# type: (RRef[MyScriptClass]) -> int
|
||||
return rref.to_here().get_value()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def construct_my_script_module(rank):
|
||||
# type: (int) -> MyModuleInterface
|
||||
return my_script_module_init(rank)
|
||||
def script_run_forward_rref_my_script_module(rref):
|
||||
# type: (RRef[MyModuleInterface]) -> Tensor
|
||||
return rref.to_here().forward()
|
||||
|
||||
|
||||
class LocalRRefTest(RpcAgentTestFixture):
|
||||
@dist_init
|
||||
def test_create_local_script_class_rref_in_py(self):
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
# Create a local RRef<MyScriptClass>.
|
||||
rref_script_class = rpc.RRef(MyScriptClass(self.rank, ))
|
||||
ret = rref_script_class.to_here().get_value()
|
||||
self.assertEqual(ret, self.rank)
|
||||
|
||||
@dist_init
|
||||
def test_create_local_script_module_rref_in_py(self):
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
# Create a local RRef<MyModuleInterface>.
|
||||
rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
|
||||
ret = rref_script_module.to_here().forward()
|
||||
self.assertEqual(ret, torch.ones(self.rank))
|
||||
|
||||
# Create a local RRef<MyModuleInterface> without type hint.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, (
|
||||
"The RRef being created contains a ScriptModule, "
|
||||
"must provide its ModuleInterface type hint."
|
||||
)
|
||||
):
|
||||
rref_script_module = rpc.RRef(MyScriptModule(self.rank))
|
||||
|
||||
@dist_init
|
||||
def test_return_local_script_class_rref_in_py_and_use_in_script(self):
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
|
||||
|
||||
# Create a local RRef<MyScripClass> remotely in Python.
|
||||
rref = rpc.rpc_sync(dst_worker_name, owner_create_rref_my_script_class, args=(self.rank,))
|
||||
# Use RRef<MyScripClass> remotely in Script.
|
||||
ret = rpc.rpc_sync(
|
||||
rref.owner(), script_run_get_value_rref_my_script_class, args=(rref,)
|
||||
)
|
||||
self.assertEqual(ret, self.rank)
|
||||
|
||||
@dist_init
|
||||
def test_return_local_script_module_rref_in_py_and_use_in_script(self):
|
||||
if self.rank != 0:
|
||||
return
|
||||
|
||||
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
|
||||
|
||||
# Create a local RRef<MyModuleInterface> remotely in Python.
|
||||
rref = rpc.rpc_sync(dst_worker_name, owner_create_rref_my_script_module, args=(self.rank,))
|
||||
# Use RRef<MyModuleInterface> remotely in Script.
|
||||
ret = rpc.rpc_sync(
|
||||
rref.owner(), script_run_forward_rref_my_script_module, args=(rref,)
|
||||
)
|
||||
self.assertEqual(ret, torch.ones(self.rank))
|
||||
|
||||
|
||||
def python_function():
|
||||
return 0
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def run_ref_script_module(ref_script_module, t):
|
||||
# type: (RRef[MyModuleInterface], Tensor) -> Tensor
|
||||
module = ref_script_module.to_here()
|
||||
return module.forward() + t
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def rref_python_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||
return rref_var
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rref_script_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> Tensor
|
||||
return rref_python_annotation(rref_var).to_here()
|
||||
def no_arg():
|
||||
return 0
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@ -450,10 +487,58 @@ class JitRpcAsyncOpTest:
|
||||
self.assertEqual(ret, 0)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def one_arg(value):
|
||||
return value + 1
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rref_to_here(rref_var):
|
||||
# type: (RRef[Tensor]) -> Tensor
|
||||
return rref_var.to_here()
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def return_rref(rref_var):
|
||||
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||
return rref_var
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def my_script_module_init(rank):
|
||||
# type: (int) -> MyModuleInterface
|
||||
return MyScriptModule(rank)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def construct_my_script_module(rank):
|
||||
# type: (int) -> MyModuleInterface
|
||||
return my_script_module_init(rank)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def run_ref_script_module(ref_script_module, t):
|
||||
# type: (RRef[MyModuleInterface], Tensor) -> Tensor
|
||||
module = ref_script_module.to_here()
|
||||
return module.forward() + t
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def rref_python_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> RRef[Tensor]
|
||||
return rref_var
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rref_script_annotation(rref_var):
|
||||
# type: (RRef[Tensor]) -> Tensor
|
||||
return rref_python_annotation(rref_var).to_here()
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._six.PY3, "Pytorch distributed rpc package does not support python2"
|
||||
)
|
||||
class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
|
||||
class JitRpcTest(LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixture):
|
||||
@dist_init
|
||||
def test_torchscript_function(self):
|
||||
dst_worker_name = "worker{}".format((self.rank + 1) % self.world_size)
|
||||
@ -484,7 +569,7 @@ class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
|
||||
# rpc_sync still accepts script class and run it in
|
||||
# the same code path as python call.
|
||||
ret = rpc.rpc_sync(
|
||||
dst_worker_name, MyScriptClass, args=()
|
||||
dst_worker_name, MyScriptClass, args=(self.rank,)
|
||||
)
|
||||
|
||||
# rpc_sync does not accept script module and script module method.
|
||||
@ -588,14 +673,3 @@ class JitRpcTest(JitRpcAsyncOpTest, RpcAgentTestFixture):
|
||||
|
||||
res = rref_script_annotation(rref_var)
|
||||
self.assertEqual(res, torch.ones(2, 2) + 1)
|
||||
|
||||
@dist_init
|
||||
def test_local_rref_creation_with_ivalue(self):
|
||||
|
||||
# create a local RRef that holds a IValue
|
||||
rref_local_script_class = rpc.RRef(MyScriptClass())
|
||||
self.assertEqual(rref_local_script_class.to_here().a, 10)
|
||||
|
||||
# create a local RRef that holds a ScriptModule
|
||||
rref_local_script_mod = rpc.RRef(MyScriptModule(3)._c)
|
||||
self.assertEqual(rref_local_script_mod.to_here().forward(), torch.ones(3))
|
||||
|
Reference in New Issue
Block a user