[rpc] Switch RRef to be managed by intrusive_ptr (#33189)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189

Add RRefInterface to Aten/Core, which will later be used by IValue

Switch all the rpc code base to use intrusive_ptr instead of shared_ptr,
so that we could add it to IValue.

Actual adding to IValue and JIT will be in next PR

Test Plan: Imported from OSS

Differential Revision: D19871241

Pulled By: wanchaol

fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
This commit is contained in:
Wanchao Liang
2020-02-13 20:13:10 -08:00
committed by Facebook Github Bot
parent cb4e6d025a
commit 9ae4d38a21
10 changed files with 96 additions and 104 deletions

View File

@ -59,7 +59,7 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
/////////////////////////// PyRRef //////////////////////////////////
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}
@ -87,18 +87,15 @@ py::object PyRRef::toHere() {
} else {
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
// a python object
std::vector<IValue> rawValues =
std::static_pointer_cast<UserRRef>(rref_)->toHere();
IValue value;
IValue value =
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
if (rref_->isPyObj()) {
value = jit::toIValue(
PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(std::move(rawValues))),
PyObjectType::get());
// python_rpc_handler deserialization will acquires GIL.
auto rfr_values = value.toTuple()->elements();
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr_values)
);
} else {
value = std::move(rawValues).front();
}
{
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.
pybind11::gil_scoped_acquire ag;
@ -114,7 +111,7 @@ py::object PyRRef::localValue() {
owner().name_);
py::object res;
auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
auto& rpcHandler = PythonRpcHandler::getInstance();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object without
@ -131,8 +128,8 @@ std::string PyRRef::str() const {
if (rref_->isOwner()) {
ss << "OwnerRRef(" << rref_->rrefId() << ")";
} else {
ss << "UserRRef(RRefId = " << rref_->rrefId()
<< ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = "
<< c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId()
<< ")";
}
return ss.str();
@ -151,10 +148,9 @@ py::tuple PyRRef::pickle() const {
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
auto& ctx = RRefContext::getInstance();
auto rrefForkData = fromPyTuple(pyTuple);
std::shared_ptr<RRef> rref = nullptr;
TypePtr rrefType =
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
return PyRRef(std::move(rref));