[rpc] Switch RRef to be managed by intrusive_ptr (#33189)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189

Add RRefInterface to ATen/core, which will later be used by IValue.

Switch the entire rpc code base from shared_ptr to intrusive_ptr,
so that RRef can later be added to IValue.

The actual addition to IValue and JIT will be in the next PR.

Test Plan: Imported from OSS

Differential Revision: D19871241

Pulled By: wanchaol

fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
This commit is contained in:
Wanchao Liang
2020-02-13 20:13:10 -08:00
committed by Facebook Github Bot
parent cb4e6d025a
commit 9ae4d38a21
10 changed files with 96 additions and 104 deletions

View File

@ -1,14 +1,16 @@
#pragma once
#include <torch/csrc/distributed/rpc/types.h>
#include <c10/util/intrusive_ptr.h>
namespace torch {
namespace distributed {
namespace rpc {
namespace c10 {
struct Type;
using TypePtr = std::shared_ptr<Type>;
using worker_id_t = int16_t;
// This abstract class contains only user-facing APIs, and will be shared
// between jit and distributed to implement TorchScript support.
class RRefInterface {
class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
public:
RRefInterface() = default;
// RRef is made NOT copyable NOT movable to prevent messing up reference
@ -24,8 +26,8 @@ class RRefInterface {
// Returns true if this is the ``OwnerRRef``
virtual bool isOwner() const = 0;
virtual const TypePtr type() const = 0;
};
} // namespace rpc
} // namespace distributed
} // namespace torch
}

View File

@ -59,7 +59,7 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
/////////////////////////// PyRRef //////////////////////////////////
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}
@ -87,18 +87,15 @@ py::object PyRRef::toHere() {
} else {
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
// a python object
std::vector<IValue> rawValues =
std::static_pointer_cast<UserRRef>(rref_)->toHere();
IValue value;
IValue value =
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
if (rref_->isPyObj()) {
value = jit::toIValue(
PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(std::move(rawValues))),
PyObjectType::get());
// python_rpc_handler deserialization will acquire the GIL.
auto rfr_values = value.toTuple()->elements();
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr_values)
);
} else {
value = std::move(rawValues).front();
}
{
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.
pybind11::gil_scoped_acquire ag;
@ -114,7 +111,7 @@ py::object PyRRef::localValue() {
owner().name_);
py::object res;
auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
auto& rpcHandler = PythonRpcHandler::getInstance();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object without
@ -131,8 +128,8 @@ std::string PyRRef::str() const {
if (rref_->isOwner()) {
ss << "OwnerRRef(" << rref_->rrefId() << ")";
} else {
ss << "UserRRef(RRefId = " << rref_->rrefId()
<< ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = "
<< c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId()
<< ")";
}
return ss.str();
@ -151,10 +148,9 @@ py::tuple PyRRef::pickle() const {
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
auto& ctx = RRefContext::getInstance();
auto rrefForkData = fromPyTuple(pyTuple);
std::shared_ptr<RRef> rref = nullptr;
TypePtr rrefType =
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
return PyRRef(std::move(rref));

View File

@ -12,9 +12,8 @@ namespace rpc {
// pickle and unpickle.
class PyRRef {
public:
explicit PyRRef(std::shared_ptr<RRef> rref);
// creates a local RRef with the given object as value
explicit PyRRef(const py::object& value);
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
bool isOwner() const;
WorkerInfo owner() const;
@ -25,7 +24,7 @@ class PyRRef {
static PyRRef unpickle(const py::tuple& t);
private:
std::shared_ptr<RRef> rref_;
c10::intrusive_ptr<RRef> rref_;
};
} // namespace rpc

View File

@ -156,7 +156,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
case MessageType::SCRIPT_RREF_FETCH_CALL: {
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
c10::intrusive_ptr<OwnerRRef> rref =
ctx.getOwnerRRef(srf.rrefId());
if (rref->hasValue()) { // optional fast-path
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
}
@ -181,7 +182,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
case MessageType::PYTHON_RREF_FETCH_CALL: {
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(prf.rrefId());
c10::intrusive_ptr<OwnerRRef> rref =
ctx.getOwnerRRef(prf.rrefId());
if (rref->hasValue()) { // optional fast-path
auto value = rref->getValue();
py::object pyValue;

View File

@ -28,7 +28,7 @@ RRefContext& RRefContext::getInstance() {
return *context;
}
std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance(
std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
bool ignoreRRefLeak) {
auto& ctx = RRefContext::getInstance();
{
@ -36,7 +36,7 @@ std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance(
ctx.destroyed_ = true;
}
ctx.checkRRefLeaks(ignoreRRefLeak);
std::vector<std::shared_ptr<RRef>> deletedRRefs;
std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
for (auto& entry : ctx.owners_) {
auto rref = entry.second;
if (rref->isPyObj()) {
@ -105,9 +105,7 @@ void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
}
}
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
worker_id_t ownerId,
const TypePtr& type) {
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(worker_id_t ownerId, const TypePtr& type) {
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
// Explicitly creating rrefId before forkId to make sure the order is
// deterministic, as the argument evaluation order is system and compiler
@ -117,7 +115,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef(
return createUserRRef(ownerId, rrefId, forkId, type);
}
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId,
@ -136,7 +134,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef(
// The reason for not adding the pending user here is to put addPendingUser()
// close to where the RPC occurs, and it is more clear to pair it with
// deletePendingUser() in the response callback at the call site.
return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId, type));
return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
}
void RRefContext::delUser(
@ -156,7 +154,7 @@ void RRefContext::delUser(
}
}
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
const RRefForkData& rrefForkData,
const TypePtr& type) {
auto& ownerId = rrefForkData.ownerId_;
@ -171,7 +169,7 @@ std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
}
}
std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
const RRefId& rrefId,
const TypePtr& type) {
std::lock_guard<std::mutex> lock(mutex_);
@ -182,41 +180,40 @@ std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
// NB: cannot use make_shared here as the constructor of OwnerRRef is
// private.
auto rref =
std::shared_ptr<OwnerRRef>(new OwnerRRef(getWorkerId(), rrefId, type));
c10::make_intrusive<OwnerRRef>(getWorkerId(), rrefId, type);
owners_[rref->rrefId()] = rref;
ownerCV_.notify_all();
return rref;
} else {
// Scenario (2) retrieving an existing RRef
auto ownerRRef = std::static_pointer_cast<OwnerRRef>(iter->second);
auto ownerRRef = c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second);
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
return ownerRRef;
}
}
std::shared_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
// Don't add this OwnerRRef to the owners_ map yet, otherwise
// it will never be removed from there. Instead, only add it to the
// map in prepareChildFork, in case this local RRef is being passed
// to another worker.
return std::shared_ptr<OwnerRRef>(
new OwnerRRef(getWorkerId(), genGloballyUniqueId(), type));
return c10::make_intrusive<OwnerRRef>(getWorkerId(), genGloballyUniqueId(), type);
}
std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
c10::intrusive_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
std::unique_lock<std::mutex> lock(mutex_);
const auto iter = owners_.find(rrefId);
if (iter == owners_.end()) {
// Scenario (1) RRef is used before it is created
ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); });
return std::static_pointer_cast<OwnerRRef>(owners_[rrefId]);
return c10::static_intrusive_pointer_cast<OwnerRRef>(owners_[rrefId]);
} else {
// Scenario (2) retrieving an existing RRef
return std::static_pointer_cast<OwnerRRef>(iter->second);
return c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second);
}
}
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
RRefForkData RRefContext::prepareChildFork(const c10::intrusive_ptr<RRef>& rref) {
auto rrefForkData = rref->fork();
if (rref->isOwner()) {
// Note [Early Fork Registration]
@ -256,7 +253,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
void RRefContext::notifyOwnerAndParentOfFork(
const ForkId& forkId,
worker_id_t parent,
const std::shared_ptr<RRef>& rref) {
const c10::intrusive_ptr<RRef>& rref) {
if (parent == rref->owner()) {
if (parent == agent_->getWorkerInfo().id_) {
// Owner sending RRef to self, remove the forkId as it was added during
@ -310,7 +307,7 @@ void RRefContext::notifyOwnerAndParentOfFork(
void RRefContext::addPendingChild(
const ForkId& forkId,
const std::shared_ptr<RRef>& rref) {
const c10::intrusive_ptr<RRef>& rref) {
// see Note [Early Fork Registration]
// If the parent is the owner, it should directly add the child UserRRef as a
// fork.
@ -334,7 +331,9 @@ void RRefContext::delPendingChild(const ForkId& forkId) {
void RRefContext::addPendingUser(
const ForkId& forkId,
const std::shared_ptr<RRef>& rref) {
const c10::intrusive_ptr<RRef>& rref) {
TORCH_INTERNAL_ASSERT(
!rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");
std::lock_guard<std::mutex> lock(mutex_);
TORCH_INTERNAL_ASSERT(
pendingUsers_.find(forkId) == pendingUsers_.end(),
@ -362,7 +361,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
});
}
void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef>& rref) {
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
std::lock_guard<std::mutex> lock(mutex_);
const auto& rrefId = rref->rrefId();
owners_[rrefId] = rref;
@ -384,10 +383,10 @@ void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
rrefForks.insert(forkId);
}
std::shared_ptr<RRef> RRefContext::delForkOfOwner(
c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
const RRefId& rrefId,
const ForkId& forkId) {
std::shared_ptr<RRef> deletedRRef = nullptr;
c10::intrusive_ptr<RRef> deletedRRef;
{
std::lock_guard<std::mutex> lock(mutex_);
auto rrefIter = forks_.find(rrefId);

View File

@ -28,7 +28,7 @@ class TORCH_API RRefContext {
// hold py::object. The call-site is also responsible for resetting those
// intrusive_ptr objects with a GIL. See comments at delForkOfOwner() for more
// details.
static std::vector<std::shared_ptr<RRef>> destroyInstance(
static std::vector<c10::intrusive_ptr<RRef>> destroyInstance(
bool ignoreRRefLeak = true);
static void handleException(const c10::optional<utils::FutureError>& futErr);
@ -60,27 +60,21 @@ class TORCH_API RRefContext {
}
// create a ``UserRRef`` owned by the worker ``ownerId``
std::shared_ptr<UserRRef> createUserRRef(
worker_id_t ownerId,
const TypePtr& type);
c10::intrusive_ptr<UserRRef> createUserRRef(worker_id_t ownerId, const TypePtr& type);
// Convert an RRefForkData into an RRef. This RRef could be user or owner.
// This RRef could have already existed before, or could be created in this
// method, we pass type here to validate or help the rref creation.
std::shared_ptr<RRef> getOrCreateRRef(
const RRefForkData& rfd,
const TypePtr& type);
c10::intrusive_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd, const TypePtr& type);
// Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
// one.
std::shared_ptr<OwnerRRef> getOrCreateOwnerRRef(
const RRefId& rrefId,
const TypePtr& type);
c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef(const RRefId& rrefId, const TypePtr& type);
// Create an empty owner rref of type.
std::shared_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
std::shared_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
c10::intrusive_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
// Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
// making a remote call to self, which, for now, still goes through serde
@ -92,9 +86,9 @@ class TORCH_API RRefContext {
// and this could happen before the self remote call finishes. To prevent
// that, this API adds the RRefId as a ForkId, which will then delete the
// ForkId when the self remote is done.
void addSelfAsFork(std::shared_ptr<OwnerRRef>& rref);
void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref);
// Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the
// Register a fork of the ``OwnerRRef``, and insert an intrusive_ptr to the
// ``OwnerRRef`` in a map to keep it alive.
void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
// Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the
@ -106,19 +100,19 @@ class TORCH_API RRefContext {
// py::object, deleting it requires the GIL. The call site should guard it
// with the GIL and reset the intrusive_ptr. The GIL-guarded deletion is
// intentionally left out of this function to avoid creating a dependency on
// pybind.
std::shared_ptr<RRef> delForkOfOwner(
c10::intrusive_ptr<RRef> delForkOfOwner(
const RRefId& rrefId,
const ForkId& forkId);
// Invoked when pickling an RRef to setup child/fork properly
RRefForkData prepareChildFork(const std::shared_ptr<RRef>& rref);
RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref);
// Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and
// send RREF_CHILD_ACCEPT to the parent.
// NB: forkId is necessary here as the rref could be an OwnerRRef
void notifyOwnerAndParentOfFork(
const ForkId& forkId,
worker_id_t parent,
const std::shared_ptr<RRef>& rref);
const c10::intrusive_ptr<RRef>& rref);
// When a UserRRef is forked to another worker (user or owner), it is added
// into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT
@ -128,12 +122,12 @@ class TORCH_API RRefContext {
// previously submitted rpc/remote calls are acked before sending out the
// RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
// soon.
void addPendingChild(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
void addPendingChild(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref);
void delPendingChild(const ForkId& forkId);
// When a UserRRef is created, it is added into pendingUsers_ to be held alive
// until it receives RREF_USER_ACCEPT from the owner.
void addPendingUser(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
void addPendingUser(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref);
void delPendingUser(const ForkId& forkId);
void delUser(
@ -146,7 +140,7 @@ class TORCH_API RRefContext {
private:
RRefContext(std::shared_ptr<RpcAgent>);
std::shared_ptr<UserRRef> createUserRRef(
c10::intrusive_ptr<UserRRef> createUserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId,
@ -162,7 +156,7 @@ class TORCH_API RRefContext {
const std::shared_ptr<RpcAgent> agent_;
mutable std::mutex mutex_;
// Keep OwnerRRefs alive until there is no living UserRRefs.
std::unordered_map<RRefId, std::shared_ptr<RRef>, RRefId::Hash> owners_;
std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_;
// A condition variable to block getOwnerRRef() calls until the
// corresponding OwnerRRef has been created and inserted into the owners_ map.
// The method getOwnerRRef() is triggered by rref.to_here() messages. The
@ -184,7 +178,7 @@ class TORCH_API RRefContext {
RRefId::Hash>
forks_;
// The follow two maps keep UserRRefs alive by holding a shared_ptr to the
// The following two maps keep UserRRefs alive by holding an intrusive_ptr to the
// RRef instances. A UserRRef must be added into this map if any of the
// following two conditions is true:
//
@ -193,7 +187,7 @@ class TORCH_API RRefContext {
// It can be used or shared, but cannot be deleted, and hence kept alive
// in this map. A message of type RREF_USER_ACCEPT will remove the
// corresponding RRef from this map.
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash> pendingUsers_;
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> pendingUsers_;
// (2) A UserRRef has forked a child UserRRef which has not been accepted by
// the owner yet.
@ -201,7 +195,7 @@ class TORCH_API RRefContext {
// In this case, this UserRRef cannot send out RREF_USER_DELETE message,
// as it could potentially trigger the OwnerRRef being deleted before the
// owner learns about the forked child.
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash>
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
pendingChildren_;
std::mutex destroyedMutex_;

View File

@ -77,7 +77,7 @@ const ForkId& UserRRef::forkId() const {
return forkId_;
}
std::vector<IValue> UserRRef::toHere() {
IValue UserRRef::toHere() {
auto agent = RpcAgent::getCurrentRpcAgent();
// ScriptRRefFetchCall message always carries autograd context id even if
@ -107,7 +107,13 @@ std::vector<IValue> UserRRef::toHere() {
"or PYTHON_RREF_FETCH_RET");
RpcCommandBase& rpc = *response;
auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
return rrefFetchRet.values();
if (isPyObj()) {
// Wrap the Python-serialized vector of IValues into a tuple, so that
// the C++ toHere interface returns a single IValue.
return ivalue::Tuple::create(rrefFetchRet.values());
} else {
return rrefFetchRet.values().front();
}
}
////////////////////////// OwnerRRef /////////////////////////////////////

View File

@ -1,10 +1,10 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <ATen/core/rref_interface.h>
#include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_interface.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <atomic>
@ -27,12 +27,13 @@ struct TORCH_API RRefForkData {
RRefForkData(
worker_id_t ownerId,
const RRefId& rrefId_,
const ForkId& forkId_,
const RRefId& rrefId,
const ForkId& forkId,
worker_id_t parent,
std::string typeStr);
};
// Note [RRef Protocol]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
//
@ -198,7 +199,7 @@ class TORCH_API RRef : public RRefInterface {
inline bool isPyObj() {
return type_ == PyObjectType::get();
}
inline const TypePtr type() {
inline const TypePtr type() const override{
return type_;
}
@ -228,6 +229,8 @@ class TORCH_API UserRRef final : public RRef {
UserRRef& operator=(const UserRRef& other) = delete;
UserRRef& operator=(UserRRef&& other) = delete;
UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId, TypePtr type);
inline bool isOwner() const override {
return false;
}
@ -237,7 +240,7 @@ class TORCH_API UserRRef final : public RRef {
// Get a copy of the value from the ``OwnerRRef``. If the value is not ready
// yet, this call will block.
std::vector<IValue> toHere();
IValue toHere();
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
~UserRRef() override;
@ -245,12 +248,6 @@ class TORCH_API UserRRef final : public RRef {
private:
friend class RRefContext;
UserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId,
TypePtr type);
const ForkId forkId_;
};
@ -263,6 +260,15 @@ class TORCH_API OwnerRRef final : public RRef {
OwnerRRef& operator=(const OwnerRRef& other) = delete;
OwnerRRef& operator=(OwnerRRef&& other) = delete;
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
: OwnerRRef(ownerId, rrefId, type, {}) {}
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type, c10::optional<IValue> value)
: RRef(ownerId, rrefId, std::move(type)) {
value_ = std::move(value);
}
inline bool isOwner() const override {
return true;
}
@ -284,18 +290,6 @@ class TORCH_API OwnerRRef final : public RRef {
private:
friend class RRefContext;
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
: OwnerRRef(ownerId, rrefId, type, {}) {}
OwnerRRef(
worker_id_t ownerId,
const RRefId& rrefId,
TypePtr type,
c10::optional<IValue> value)
: RRef(ownerId, rrefId, std::move(type)) {
value_ = std::move(value);
}
c10::optional<IValue> value_;
mutable std::mutex mutex_;
mutable std::condition_variable valueCV_;

View File

@ -49,7 +49,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
return futPtr;
}
std::shared_ptr<UserRRef> remoteTorchscript(
c10::intrusive_ptr<UserRRef> remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,

View File

@ -25,7 +25,7 @@ c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
const c10::FunctionSchema& functionSchema,
std::vector<c10::IValue>& stack);
std::shared_ptr<UserRRef> TORCH_API remoteTorchscript(
c10::intrusive_ptr<UserRRef> TORCH_API remoteTorchscript(
const std::string& dstWorkerName,
const c10::QualifiedName& qualifiedName,
const c10::FunctionSchema& functionSchema,