[jit] allow RRef local creation with IValue objects (#33263)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33263

This PR allow PyRRef local creation to inspect the pyobject, if it
founds that we could turn it to an IValue, turn to an IValue first,
otherwise hold it as a PyObjectType

Test Plan:
Imported from OSS

https://fb.quip.com/aGxRAh2lCg05

Differential Revision: D19871243

Pulled By: wanchaol

fbshipit-source-id: ae5be3c52fb1e6db33c64e64ef64bc8b9ea63a9a
This commit is contained in:
Wanchao Liang
2020-02-27 22:47:33 -08:00
committed by Facebook Github Bot
parent 1507573a52
commit 64aab3260a
2 changed files with 25 additions and 3 deletions

View File

@ -56,11 +56,22 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
PyRRef::PyRRef(const py::object& value)
: PyRRef([&value]() {
jit::InferredType type_inferred = jit::tryToInferType(value);
TypePtr elem_type = nullptr;
if (type_inferred.success()) {
// If we could infer the type from the pyobject, we create
// the RRef with the IValue of that type.
elem_type = type_inferred.type();
} else {
// Otherwise it's a pure pyobject, create the RRef
// that holds an IValue of an pyobject
elem_type = PyObjectType::get();
}
auto rref =
RRefContext::getInstance().createOwnerRRef(PyObjectType::get());
RRefContext::getInstance().createOwnerRRef(elem_type);
py::object copy(value); // increases refcount
IValue py_ivalue = jit::toIValue(std::move(copy), PyObjectType::get());
rref->setValue(std::move(py_ivalue));
IValue ivalue = jit::toIValue(std::move(copy), elem_type);
rref->setValue(std::move(ivalue));
return rref;
}()) {}

View File

@ -223,3 +223,14 @@ class JitRpcTest(RpcAgentTestFixture):
res = rref_script_annotation(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
@dist_init
def test_local_rref_creation_with_ivalue(self):
# create a local RRef that holds a IValue
rref_local_script_class = rpc.RRef(MyScriptClass())
self.assertEqual(rref_local_script_class.to_here().a, 10)
# create a local RRef that holds a ScriptModule
rref_local_script_mod = rpc.RRef(MyScriptModule(3)._c)
self.assertEqual(rref_local_script_mod.to_here().forward(), torch.ones(3))