mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[rpc] Switch RRef to be managed by intrusive_ptr (#33189)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189 Add RRefInterface to Aten/Core, which will later be used by IValue Switch all the rpc code base to use intrusive_ptr instead of shared_ptr, so that we could add it to IValue. Actual adding to IValue and JIT will be in next PR Test Plan: Imported from OSS Differential Revision: D19871241 Pulled By: wanchaol fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
This commit is contained in:
committed by
Facebook Github Bot
parent
cb4e6d025a
commit
9ae4d38a21
@ -1,14 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
namespace c10 {
|
||||
|
||||
struct Type;
|
||||
using TypePtr = std::shared_ptr<Type>;
|
||||
using worker_id_t = int16_t;
|
||||
|
||||
// This abstract class contains only user-facing APIs, and will be shared
|
||||
// between jit and distributed to implement TorchScript support.
|
||||
class RRefInterface {
|
||||
class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
|
||||
public:
|
||||
RRefInterface() = default;
|
||||
// RRef is made NOT copyable NOT movable to prevent messing up reference
|
||||
@ -24,8 +26,8 @@ class RRefInterface {
|
||||
|
||||
// Returns true if this is the ``OwnerRRef``
|
||||
virtual bool isOwner() const = 0;
|
||||
|
||||
virtual const TypePtr type() const = 0;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
}
|
@ -59,7 +59,7 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) {
|
||||
|
||||
/////////////////////////// PyRRef //////////////////////////////////
|
||||
|
||||
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
|
||||
}
|
||||
|
||||
@ -87,18 +87,15 @@ py::object PyRRef::toHere() {
|
||||
} else {
|
||||
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
|
||||
// a python object
|
||||
std::vector<IValue> rawValues =
|
||||
std::static_pointer_cast<UserRRef>(rref_)->toHere();
|
||||
IValue value;
|
||||
IValue value =
|
||||
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
|
||||
if (rref_->isPyObj()) {
|
||||
value = jit::toIValue(
|
||||
PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(std::move(rawValues))),
|
||||
PyObjectType::get());
|
||||
// python_rpc_handler deserialization will acquires GIL.
|
||||
auto rfr_values = value.toTuple()->elements();
|
||||
return PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(rfr_values)
|
||||
);
|
||||
} else {
|
||||
value = std::move(rawValues).front();
|
||||
}
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object
|
||||
// without grabbing the GIL.
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
@ -114,7 +111,7 @@ py::object PyRRef::localValue() {
|
||||
owner().name_);
|
||||
|
||||
py::object res;
|
||||
auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
|
||||
auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
|
||||
auto& rpcHandler = PythonRpcHandler::getInstance();
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object without
|
||||
@ -131,8 +128,8 @@ std::string PyRRef::str() const {
|
||||
if (rref_->isOwner()) {
|
||||
ss << "OwnerRRef(" << rref_->rrefId() << ")";
|
||||
} else {
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId()
|
||||
<< ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = "
|
||||
<< c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId()
|
||||
<< ")";
|
||||
}
|
||||
return ss.str();
|
||||
@ -151,10 +148,9 @@ py::tuple PyRRef::pickle() const {
|
||||
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
auto rrefForkData = fromPyTuple(pyTuple);
|
||||
std::shared_ptr<RRef> rref = nullptr;
|
||||
TypePtr rrefType =
|
||||
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
|
||||
rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
|
||||
c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
|
||||
ctx.notifyOwnerAndParentOfFork(
|
||||
rrefForkData.forkId_, rrefForkData.parent_, rref);
|
||||
return PyRRef(std::move(rref));
|
||||
|
@ -12,9 +12,8 @@ namespace rpc {
|
||||
// pickle and unpickle.
|
||||
class PyRRef {
|
||||
public:
|
||||
explicit PyRRef(std::shared_ptr<RRef> rref);
|
||||
// creates a local RRef with the given object as value
|
||||
explicit PyRRef(const py::object& value);
|
||||
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
|
||||
|
||||
bool isOwner() const;
|
||||
WorkerInfo owner() const;
|
||||
@ -25,7 +24,7 @@ class PyRRef {
|
||||
static PyRRef unpickle(const py::tuple& t);
|
||||
|
||||
private:
|
||||
std::shared_ptr<RRef> rref_;
|
||||
c10::intrusive_ptr<RRef> rref_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -156,7 +156,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
case MessageType::SCRIPT_RREF_FETCH_CALL: {
|
||||
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
|
||||
c10::intrusive_ptr<OwnerRRef> rref =
|
||||
ctx.getOwnerRRef(srf.rrefId());
|
||||
if (rref->hasValue()) { // optional fast-path
|
||||
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
|
||||
}
|
||||
@ -181,7 +182,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
case MessageType::PYTHON_RREF_FETCH_CALL: {
|
||||
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(prf.rrefId());
|
||||
c10::intrusive_ptr<OwnerRRef> rref =
|
||||
ctx.getOwnerRRef(prf.rrefId());
|
||||
if (rref->hasValue()) { // optional fast-path
|
||||
auto value = rref->getValue();
|
||||
py::object pyValue;
|
||||
|
@ -28,7 +28,7 @@ RRefContext& RRefContext::getInstance() {
|
||||
return *context;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance(
|
||||
std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
|
||||
bool ignoreRRefLeak) {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
{
|
||||
@ -36,7 +36,7 @@ std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance(
|
||||
ctx.destroyed_ = true;
|
||||
}
|
||||
ctx.checkRRefLeaks(ignoreRRefLeak);
|
||||
std::vector<std::shared_ptr<RRef>> deletedRRefs;
|
||||
std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
|
||||
for (auto& entry : ctx.owners_) {
|
||||
auto rref = entry.second;
|
||||
if (rref->isPyObj()) {
|
||||
@ -105,9 +105,7 @@ void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const TypePtr& type) {
|
||||
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(worker_id_t ownerId, const TypePtr& type) {
|
||||
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
|
||||
// Explicitly creating rrefId before forkId to make sure the order is
|
||||
// deterministic, as the argument evaluation order is system and compiler
|
||||
@ -117,7 +115,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
return createUserRRef(ownerId, rrefId, forkId, type);
|
||||
}
|
||||
|
||||
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
@ -136,7 +134,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
// The reason for not adding the pending user here is to put addPendingUser()
|
||||
// close to where the RPC occurs, and it is more clear to pair it with
|
||||
// deletePendingUser() in the response callback at the call site.
|
||||
return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId, type));
|
||||
return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
|
||||
}
|
||||
|
||||
void RRefContext::delUser(
|
||||
@ -156,7 +154,7 @@ void RRefContext::delUser(
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
|
||||
c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
|
||||
const RRefForkData& rrefForkData,
|
||||
const TypePtr& type) {
|
||||
auto& ownerId = rrefForkData.ownerId_;
|
||||
@ -171,7 +169,7 @@ std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
|
||||
c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const TypePtr& type) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
@ -182,41 +180,40 @@ std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
|
||||
// NB: cannot use make_shared here as the constructor of OwnerRRef is
|
||||
// private.
|
||||
auto rref =
|
||||
std::shared_ptr<OwnerRRef>(new OwnerRRef(getWorkerId(), rrefId, type));
|
||||
c10::make_intrusive<OwnerRRef>(getWorkerId(), rrefId, type);
|
||||
owners_[rref->rrefId()] = rref;
|
||||
ownerCV_.notify_all();
|
||||
return rref;
|
||||
} else {
|
||||
// Scenario (2) retrieving an existing RRef
|
||||
auto ownerRRef = std::static_pointer_cast<OwnerRRef>(iter->second);
|
||||
auto ownerRRef = c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second);
|
||||
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
|
||||
return ownerRRef;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
|
||||
c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
|
||||
// Don't add this OnwerRRef to the owners_ map yet, otherwise
|
||||
// it will never be removed from there. Instead, only add it to the
|
||||
// map in prepareChildFork, in case this local RRef is being passed
|
||||
// to another worker.
|
||||
return std::shared_ptr<OwnerRRef>(
|
||||
new OwnerRRef(getWorkerId(), genGloballyUniqueId(), type));
|
||||
return c10::make_intrusive<OwnerRRef>(getWorkerId(), genGloballyUniqueId(), type);
|
||||
}
|
||||
|
||||
std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
|
||||
c10::intrusive_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
const auto iter = owners_.find(rrefId);
|
||||
if (iter == owners_.end()) {
|
||||
// Scenario (1) RRef is used before it is created
|
||||
ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); });
|
||||
return std::static_pointer_cast<OwnerRRef>(owners_[rrefId]);
|
||||
return c10::static_intrusive_pointer_cast<OwnerRRef>(owners_[rrefId]);
|
||||
} else {
|
||||
// Scenario (2) retrieving an existing RRef
|
||||
return std::static_pointer_cast<OwnerRRef>(iter->second);
|
||||
return c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second);
|
||||
}
|
||||
}
|
||||
|
||||
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
|
||||
RRefForkData RRefContext::prepareChildFork(const c10::intrusive_ptr<RRef>& rref) {
|
||||
auto rrefForkData = rref->fork();
|
||||
if (rref->isOwner()) {
|
||||
// Note [Early Fork Registration]
|
||||
@ -256,7 +253,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
|
||||
void RRefContext::notifyOwnerAndParentOfFork(
|
||||
const ForkId& forkId,
|
||||
worker_id_t parent,
|
||||
const std::shared_ptr<RRef>& rref) {
|
||||
const c10::intrusive_ptr<RRef>& rref) {
|
||||
if (parent == rref->owner()) {
|
||||
if (parent == agent_->getWorkerInfo().id_) {
|
||||
// Owner sending RRef to self, remove the forkId as it was added during
|
||||
@ -310,7 +307,7 @@ void RRefContext::notifyOwnerAndParentOfFork(
|
||||
|
||||
void RRefContext::addPendingChild(
|
||||
const ForkId& forkId,
|
||||
const std::shared_ptr<RRef>& rref) {
|
||||
const c10::intrusive_ptr<RRef>& rref) {
|
||||
// see Note [Early Fork Registration]
|
||||
// If the parent is the owner, it should directly add the child UserRRef as a
|
||||
// fork.
|
||||
@ -334,7 +331,9 @@ void RRefContext::delPendingChild(const ForkId& forkId) {
|
||||
|
||||
void RRefContext::addPendingUser(
|
||||
const ForkId& forkId,
|
||||
const std::shared_ptr<RRef>& rref) {
|
||||
const c10::intrusive_ptr<RRef>& rref) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
pendingUsers_.find(forkId) == pendingUsers_.end(),
|
||||
@ -362,7 +361,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
|
||||
});
|
||||
}
|
||||
|
||||
void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef>& rref) {
|
||||
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto& rrefId = rref->rrefId();
|
||||
owners_[rrefId] = rref;
|
||||
@ -384,10 +383,10 @@ void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
|
||||
rrefForks.insert(forkId);
|
||||
}
|
||||
|
||||
std::shared_ptr<RRef> RRefContext::delForkOfOwner(
|
||||
c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId) {
|
||||
std::shared_ptr<RRef> deletedRRef = nullptr;
|
||||
c10::intrusive_ptr<RRef> deletedRRef;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto rrefIter = forks_.find(rrefId);
|
||||
|
@ -28,7 +28,7 @@ class TORCH_API RRefContext {
|
||||
// hold py::object. The call-site is also responsible for resetting those
|
||||
// shared_ptr objects with a GIL. See comments at delForkOfOwner() for more
|
||||
// details.
|
||||
static std::vector<std::shared_ptr<RRef>> destroyInstance(
|
||||
static std::vector<c10::intrusive_ptr<RRef>> destroyInstance(
|
||||
bool ignoreRRefLeak = true);
|
||||
|
||||
static void handleException(const c10::optional<utils::FutureError>& futErr);
|
||||
@ -60,27 +60,21 @@ class TORCH_API RRefContext {
|
||||
}
|
||||
|
||||
// create a ``UserRRef`` owned by the worker ``ownerId``
|
||||
std::shared_ptr<UserRRef> createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const TypePtr& type);
|
||||
c10::intrusive_ptr<UserRRef> createUserRRef(worker_id_t ownerId, const TypePtr& type);
|
||||
|
||||
// Convert an RRefForkData into an RRef. This RRef could be user or owner.
|
||||
// This RRef could have already existed before, or could be created in this
|
||||
// method, we pass type here to validate or help the rref creation.
|
||||
std::shared_ptr<RRef> getOrCreateRRef(
|
||||
const RRefForkData& rfd,
|
||||
const TypePtr& type);
|
||||
c10::intrusive_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd, const TypePtr& type);
|
||||
|
||||
// Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
|
||||
// one.
|
||||
std::shared_ptr<OwnerRRef> getOrCreateOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const TypePtr& type);
|
||||
c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef(const RRefId& rrefId, const TypePtr& type);
|
||||
|
||||
// Create an empty owner rref of type.
|
||||
std::shared_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
|
||||
c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
|
||||
|
||||
std::shared_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
|
||||
c10::intrusive_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
|
||||
|
||||
// Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
|
||||
// making a remote call to self, which as for now, still goes through serde
|
||||
@ -92,9 +86,9 @@ class TORCH_API RRefContext {
|
||||
// and this could happen before the self remote call finishes. To prevent
|
||||
// that, this API adds the RRefId as a ForkId, which will then delete the
|
||||
// ForkId when the self remote is done.
|
||||
void addSelfAsFork(std::shared_ptr<OwnerRRef>& rref);
|
||||
void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref);
|
||||
|
||||
// Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the
|
||||
// Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the
|
||||
// ``OwnerRRef`` in a map to keep it alive.
|
||||
void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
|
||||
// Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the
|
||||
@ -106,19 +100,19 @@ class TORCH_API RRefContext {
|
||||
// py::object, deleting it require GIL. The call site should guarded it with
|
||||
// a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally
|
||||
// left out of this function to avoid creating dependency on pybind.
|
||||
std::shared_ptr<RRef> delForkOfOwner(
|
||||
c10::intrusive_ptr<RRef> delForkOfOwner(
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId);
|
||||
|
||||
// Invoked when pickling an RRef to setup child/fork properly
|
||||
RRefForkData prepareChildFork(const std::shared_ptr<RRef>& rref);
|
||||
RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref);
|
||||
// Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and
|
||||
// send RREF_CHILD_ACCEPT to the parent.
|
||||
// NB: forkId is necessary here as the rref could be an OwnerRRef
|
||||
void notifyOwnerAndParentOfFork(
|
||||
const ForkId& forkId,
|
||||
worker_id_t parent,
|
||||
const std::shared_ptr<RRef>& rref);
|
||||
const c10::intrusive_ptr<RRef>& rref);
|
||||
|
||||
// When a UserRRef is forked to another worker (user or owner), it is added
|
||||
// into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT
|
||||
@ -128,12 +122,12 @@ class TORCH_API RRefContext {
|
||||
// previously submitted rpc/remote calls are acked before sending out the
|
||||
// RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
|
||||
// soon.
|
||||
void addPendingChild(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
|
||||
void addPendingChild(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref);
|
||||
void delPendingChild(const ForkId& forkId);
|
||||
|
||||
// When a UserRRef is created, it is added into pendingUsers_ to be held alive
|
||||
// until it receives RREF_USER_ACCEPT from the owner.
|
||||
void addPendingUser(const ForkId& forkId, const std::shared_ptr<RRef>& rref);
|
||||
void addPendingUser(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref);
|
||||
void delPendingUser(const ForkId& forkId);
|
||||
|
||||
void delUser(
|
||||
@ -146,7 +140,7 @@ class TORCH_API RRefContext {
|
||||
private:
|
||||
RRefContext(std::shared_ptr<RpcAgent>);
|
||||
|
||||
std::shared_ptr<UserRRef> createUserRRef(
|
||||
c10::intrusive_ptr<UserRRef> createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
@ -162,7 +156,7 @@ class TORCH_API RRefContext {
|
||||
const std::shared_ptr<RpcAgent> agent_;
|
||||
mutable std::mutex mutex_;
|
||||
// Keep OwnerRRefs alive until there is no living UserRRefs.
|
||||
std::unordered_map<RRefId, std::shared_ptr<RRef>, RRefId::Hash> owners_;
|
||||
std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_;
|
||||
// A conditional variable to block getOwnerRRef() calls until the
|
||||
// corresponding OwnerRRef has been created and inserted into the owners_ map.
|
||||
// The method getOwnerRRef() is triggered by rref.to_here() messages. The
|
||||
@ -184,7 +178,7 @@ class TORCH_API RRefContext {
|
||||
RRefId::Hash>
|
||||
forks_;
|
||||
|
||||
// The follow two maps keep UserRRefs alive by holding a shared_ptr to the
|
||||
// The follow two maps keep UserRRefs alive by holding a intrusive_ptr to the
|
||||
// RRef instances. A UserRRef must be added into this map if any of the
|
||||
// following two conditions is true:
|
||||
//
|
||||
@ -193,7 +187,7 @@ class TORCH_API RRefContext {
|
||||
// It can be used or shared, but cannot be deleted, and hence kept alive
|
||||
// in this map. A message of type RREF_USER_ACCEPT will remove the
|
||||
// corresponding RRef from this map.
|
||||
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash> pendingUsers_;
|
||||
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> pendingUsers_;
|
||||
|
||||
// (2) A UserRRef has forked a child UserRRef which has not been accepted by
|
||||
// the owner yet.
|
||||
@ -201,7 +195,7 @@ class TORCH_API RRefContext {
|
||||
// In this case, this UserRRef cannot send out RREF_USER_DELETE message,
|
||||
// as it could potentially trigger the OwnerRRef been deleted before the
|
||||
// owner learns about the forked child.
|
||||
std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash>
|
||||
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
|
||||
pendingChildren_;
|
||||
|
||||
std::mutex destroyedMutex_;
|
||||
|
@ -77,7 +77,7 @@ const ForkId& UserRRef::forkId() const {
|
||||
return forkId_;
|
||||
}
|
||||
|
||||
std::vector<IValue> UserRRef::toHere() {
|
||||
IValue UserRRef::toHere() {
|
||||
auto agent = RpcAgent::getCurrentRpcAgent();
|
||||
|
||||
// ScriptRRefFetchCall message always carries autograd context id even if
|
||||
@ -107,7 +107,13 @@ std::vector<IValue> UserRRef::toHere() {
|
||||
"or PYTHON_RREF_FETCH_RET");
|
||||
RpcCommandBase& rpc = *response;
|
||||
auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
|
||||
return rrefFetchRet.values();
|
||||
if (isPyObj()) {
|
||||
// wrap python serialized vector of ivalues into tuple, this
|
||||
// made the C++ toHere interface to return single IValue
|
||||
return ivalue::Tuple::create(rrefFetchRet.values());
|
||||
} else {
|
||||
return rrefFetchRet.values().front();
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////// OwnerRRef /////////////////////////////////////
|
||||
|
@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/rref_interface.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||
#include <torch/csrc/distributed/rpc/rref_interface.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
|
||||
#include <atomic>
|
||||
@ -27,12 +27,13 @@ struct TORCH_API RRefForkData {
|
||||
|
||||
RRefForkData(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId_,
|
||||
const ForkId& forkId_,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
worker_id_t parent,
|
||||
std::string typeStr);
|
||||
};
|
||||
|
||||
|
||||
// Note [RRef Protocol]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
//
|
||||
@ -198,7 +199,7 @@ class TORCH_API RRef : public RRefInterface {
|
||||
inline bool isPyObj() {
|
||||
return type_ == PyObjectType::get();
|
||||
}
|
||||
inline const TypePtr type() {
|
||||
inline const TypePtr type() const override{
|
||||
return type_;
|
||||
}
|
||||
|
||||
@ -228,6 +229,8 @@ class TORCH_API UserRRef final : public RRef {
|
||||
UserRRef& operator=(const UserRRef& other) = delete;
|
||||
UserRRef& operator=(UserRRef&& other) = delete;
|
||||
|
||||
UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId, TypePtr type);
|
||||
|
||||
inline bool isOwner() const override {
|
||||
return false;
|
||||
}
|
||||
@ -237,7 +240,7 @@ class TORCH_API UserRRef final : public RRef {
|
||||
|
||||
// Get of copy of the value from the ``OwnerRRef``. If the value is not ready
|
||||
// yet, this call will block.
|
||||
std::vector<IValue> toHere();
|
||||
IValue toHere();
|
||||
|
||||
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
|
||||
~UserRRef() override;
|
||||
@ -245,12 +248,6 @@ class TORCH_API UserRRef final : public RRef {
|
||||
private:
|
||||
friend class RRefContext;
|
||||
|
||||
UserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
TypePtr type);
|
||||
|
||||
const ForkId forkId_;
|
||||
};
|
||||
|
||||
@ -263,6 +260,15 @@ class TORCH_API OwnerRRef final : public RRef {
|
||||
OwnerRRef& operator=(const OwnerRRef& other) = delete;
|
||||
OwnerRRef& operator=(OwnerRRef&& other) = delete;
|
||||
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
|
||||
: OwnerRRef(ownerId, rrefId, type, {}) {}
|
||||
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type, c10::optional<IValue> value)
|
||||
: RRef(ownerId, rrefId, std::move(type)) {
|
||||
value_ = std::move(value);
|
||||
}
|
||||
|
||||
|
||||
inline bool isOwner() const override {
|
||||
return true;
|
||||
}
|
||||
@ -284,18 +290,6 @@ class TORCH_API OwnerRRef final : public RRef {
|
||||
private:
|
||||
friend class RRefContext;
|
||||
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
|
||||
: OwnerRRef(ownerId, rrefId, type, {}) {}
|
||||
|
||||
OwnerRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
TypePtr type,
|
||||
c10::optional<IValue> value)
|
||||
: RRef(ownerId, rrefId, std::move(type)) {
|
||||
value_ = std::move(value);
|
||||
}
|
||||
|
||||
c10::optional<IValue> value_;
|
||||
mutable std::mutex mutex_;
|
||||
mutable std::condition_variable valueCV_;
|
||||
|
@ -49,7 +49,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript(
|
||||
return futPtr;
|
||||
}
|
||||
|
||||
std::shared_ptr<UserRRef> remoteTorchscript(
|
||||
c10::intrusive_ptr<UserRRef> remoteTorchscript(
|
||||
const std::string& dstWorkerName,
|
||||
const c10::QualifiedName& qualifiedName,
|
||||
const c10::FunctionSchema& functionSchema,
|
||||
|
@ -25,7 +25,7 @@ c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript(
|
||||
const c10::FunctionSchema& functionSchema,
|
||||
std::vector<c10::IValue>& stack);
|
||||
|
||||
std::shared_ptr<UserRRef> TORCH_API remoteTorchscript(
|
||||
c10::intrusive_ptr<UserRRef> TORCH_API remoteTorchscript(
|
||||
const std::string& dstWorkerName,
|
||||
const c10::QualifiedName& qualifiedName,
|
||||
const c10::FunctionSchema& functionSchema,
|
||||
|
Reference in New Issue
Block a user