mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] allow RRef local creation with IValue objects (#33263)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33263 This PR allow PyRRef local creation to inspect the pyobject, if it founds that we could turn it to an IValue, turn to an IValue first, otherwise hold it as a PyObjectType Test Plan: Imported from OSS https://fb.quip.com/aGxRAh2lCg05 Differential Revision: D19871243 Pulled By: wanchaol fbshipit-source-id: ae5be3c52fb1e6db33c64e64ef64bc8b9ea63a9a
This commit is contained in:
committed by
Facebook Github Bot
parent
1507573a52
commit
64aab3260a
@ -56,11 +56,22 @@ PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
|
||||
PyRRef::PyRRef(const py::object& value)
|
||||
: PyRRef([&value]() {
|
||||
jit::InferredType type_inferred = jit::tryToInferType(value);
|
||||
TypePtr elem_type = nullptr;
|
||||
if (type_inferred.success()) {
|
||||
// If we could infer the type from the pyobject, we create
|
||||
// the RRef with the IValue of that type.
|
||||
elem_type = type_inferred.type();
|
||||
} else {
|
||||
// Otherwise it's a pure pyobject, create the RRef
|
||||
// that holds an IValue of an pyobject
|
||||
elem_type = PyObjectType::get();
|
||||
}
|
||||
auto rref =
|
||||
RRefContext::getInstance().createOwnerRRef(PyObjectType::get());
|
||||
RRefContext::getInstance().createOwnerRRef(elem_type);
|
||||
py::object copy(value); // increases refcount
|
||||
IValue py_ivalue = jit::toIValue(std::move(copy), PyObjectType::get());
|
||||
rref->setValue(std::move(py_ivalue));
|
||||
IValue ivalue = jit::toIValue(std::move(copy), elem_type);
|
||||
rref->setValue(std::move(ivalue));
|
||||
return rref;
|
||||
}()) {}
|
||||
|
||||
|
@ -223,3 +223,14 @@ class JitRpcTest(RpcAgentTestFixture):
|
||||
|
||||
res = rref_script_annotation(rref_var)
|
||||
self.assertEqual(res, torch.ones(2, 2) + 1)
|
||||
|
||||
@dist_init
|
||||
def test_local_rref_creation_with_ivalue(self):
|
||||
|
||||
# create a local RRef that holds a IValue
|
||||
rref_local_script_class = rpc.RRef(MyScriptClass())
|
||||
self.assertEqual(rref_local_script_class.to_here().a, 10)
|
||||
|
||||
# create a local RRef that holds a ScriptModule
|
||||
rref_local_script_mod = rpc.RRef(MyScriptModule(3)._c)
|
||||
self.assertEqual(rref_local_script_mod.to_here().forward(), torch.ones(3))
|
||||
|
Reference in New Issue
Block a user