mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[rpc] Switch RRef to be managed by intrusive_ptr (#33189)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189 Add RRefInterface to Aten/Core, which will later be used by IValue Switch all the rpc code base to use intrusive_ptr instead of shared_ptr, so that we could add it to IValue. Actual adding to IValue and JIT will be in next PR Test Plan: Imported from OSS Differential Revision: D19871241 Pulled By: wanchaol fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
This commit is contained in:
committed by
Facebook Github Bot
parent
cb4e6d025a
commit
9ae4d38a21
@ -59,7 +59,7 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
|
||||
|
||||
/////////////////////////// PyRRef //////////////////////////////////
|
||||
|
||||
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
|
||||
}
|
||||
|
||||
@ -87,18 +87,15 @@ py::object PyRRef::toHere() {
|
||||
} else {
|
||||
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
|
||||
// a python object
|
||||
std::vector<IValue> rawValues =
|
||||
std::static_pointer_cast<UserRRef>(rref_)->toHere();
|
||||
IValue value;
|
||||
IValue value =
|
||||
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
|
||||
if (rref_->isPyObj()) {
|
||||
value = jit::toIValue(
|
||||
PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(std::move(rawValues))),
|
||||
PyObjectType::get());
|
||||
// python_rpc_handler deserialization will acquires GIL.
|
||||
auto rfr_values = value.toTuple()->elements();
|
||||
return PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(rfr_values)
|
||||
);
|
||||
} else {
|
||||
value = std::move(rawValues).front();
|
||||
}
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object
|
||||
// without grabbing the GIL.
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
@ -114,7 +111,7 @@ py::object PyRRef::localValue() {
|
||||
owner().name_);
|
||||
|
||||
py::object res;
|
||||
auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
|
||||
auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
|
||||
auto& rpcHandler = PythonRpcHandler::getInstance();
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object without
|
||||
@ -131,8 +128,8 @@ std::string PyRRef::str() const {
|
||||
if (rref_->isOwner()) {
|
||||
ss << "OwnerRRef(" << rref_->rrefId() << ")";
|
||||
} else {
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId()
|
||||
<< ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = "
|
||||
<< c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId()
|
||||
<< ")";
|
||||
}
|
||||
return ss.str();
|
||||
@ -151,10 +148,9 @@ py::tuple PyRRef::pickle() const {
|
||||
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
auto rrefForkData = fromPyTuple(pyTuple);
|
||||
std::shared_ptr<RRef> rref = nullptr;
|
||||
TypePtr rrefType =
|
||||
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
|
||||
rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
|
||||
c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
|
||||
ctx.notifyOwnerAndParentOfFork(
|
||||
rrefForkData.forkId_, rrefForkData.parent_, rref);
|
||||
return PyRRef(std::move(rref));
|
||||
|
Reference in New Issue
Block a user