mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[rpc] Switch RRef to be managed by intrusive_ptr (#33189)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33189 Add RRefInterface to Aten/Core, which will later be used by IValue Switch all the rpc code base to use intrusive_ptr instead of shared_ptr, so that we could add it to IValue. Actual adding to IValue and JIT will be in next PR Test Plan: Imported from OSS Differential Revision: D19871241 Pulled By: wanchaol fbshipit-source-id: d7e1fd04b46320e0f26c18591b49c92ad30a4032
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							cb4e6d025a
						
					
				
				
					commit
					9ae4d38a21
				
			| @ -1,14 +1,16 @@ | ||||
| #pragma once | ||||
| 
 | ||||
| #include <torch/csrc/distributed/rpc/types.h> | ||||
| #include <c10/util/intrusive_ptr.h> | ||||
| 
 | ||||
| namespace torch { | ||||
| namespace distributed { | ||||
| namespace rpc { | ||||
| namespace c10 { | ||||
| 
 | ||||
| struct Type; | ||||
| using TypePtr = std::shared_ptr<Type>; | ||||
| using worker_id_t = int16_t; | ||||
| 
 | ||||
| // This abstract class contains only user-facing APIs, and will be shared
 | ||||
| // between jit and distributed to implement TorchScript support.
 | ||||
| class RRefInterface { | ||||
| class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target { | ||||
|  public: | ||||
|   RRefInterface() = default; | ||||
|   // RRef is made NOT copyable NOT movable to prevent messing up reference
 | ||||
| @ -24,8 +26,8 @@ class RRefInterface { | ||||
| 
 | ||||
|   // Returns true if this is the ``OwnerRRef``
 | ||||
|   virtual bool isOwner() const = 0; | ||||
| 
 | ||||
|   virtual const TypePtr type() const = 0; | ||||
| }; | ||||
| 
 | ||||
| } // namespace rpc
 | ||||
| } // namespace distributed
 | ||||
| } // namespace torch
 | ||||
| } | ||||
| @ -59,7 +59,7 @@ RRefForkData fromPyTuple(const py::tuple& pyTuple) { | ||||
|  | ||||
| ///////////////////////////  PyRRef  ////////////////////////////////// | ||||
|  | ||||
| PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) { | ||||
| PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) { | ||||
|   TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); | ||||
| } | ||||
|  | ||||
| @ -87,18 +87,15 @@ py::object PyRRef::toHere() { | ||||
|   } else { | ||||
|     // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds | ||||
|     // a python object | ||||
|     std::vector<IValue> rawValues = | ||||
|         std::static_pointer_cast<UserRRef>(rref_)->toHere(); | ||||
|     IValue value; | ||||
|     IValue value = | ||||
|         c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(); | ||||
|     if (rref_->isPyObj()) { | ||||
|       value = jit::toIValue( | ||||
|           PythonRpcHandler::getInstance().deserialize( | ||||
|               SerializedPyObj::fromIValues(std::move(rawValues))), | ||||
|           PyObjectType::get()); | ||||
|       // python_rpc_handler deserialization will acquires GIL. | ||||
|       auto rfr_values = value.toTuple()->elements(); | ||||
|       return PythonRpcHandler::getInstance().deserialize( | ||||
|         SerializedPyObj::fromIValues(rfr_values) | ||||
|       ); | ||||
|     } else { | ||||
|       value = std::move(rawValues).front(); | ||||
|     } | ||||
|     { | ||||
|       // acquiring GIL as torch::jit::toPyObject creates new py::object | ||||
|       // without grabbing the GIL. | ||||
|       pybind11::gil_scoped_acquire ag; | ||||
| @ -114,7 +111,7 @@ py::object PyRRef::localValue() { | ||||
|       owner().name_); | ||||
|  | ||||
|   py::object res; | ||||
|   auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue(); | ||||
|   auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue(); | ||||
|   auto& rpcHandler = PythonRpcHandler::getInstance(); | ||||
|   { | ||||
|     // acquiring GIL as torch::jit::toPyObject creates new py::object without | ||||
| @ -131,8 +128,8 @@ std::string PyRRef::str() const { | ||||
|   if (rref_->isOwner()) { | ||||
|     ss << "OwnerRRef(" << rref_->rrefId() << ")"; | ||||
|   } else { | ||||
|     ss << "UserRRef(RRefId = " << rref_->rrefId() | ||||
|        << ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId() | ||||
|     ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = " | ||||
|        << c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId() | ||||
|        << ")"; | ||||
|   } | ||||
|   return ss.str(); | ||||
| @ -151,10 +148,9 @@ py::tuple PyRRef::pickle() const { | ||||
| PyRRef PyRRef::unpickle(const py::tuple& pyTuple) { | ||||
|   auto& ctx = RRefContext::getInstance(); | ||||
|   auto rrefForkData = fromPyTuple(pyTuple); | ||||
|   std::shared_ptr<RRef> rref = nullptr; | ||||
|   TypePtr rrefType = | ||||
|       PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_); | ||||
|   rref = ctx.getOrCreateRRef(rrefForkData, rrefType); | ||||
|   c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType); | ||||
|   ctx.notifyOwnerAndParentOfFork( | ||||
|       rrefForkData.forkId_, rrefForkData.parent_, rref); | ||||
|   return PyRRef(std::move(rref)); | ||||
|  | ||||
| @ -12,9 +12,8 @@ namespace rpc { | ||||
| // pickle and unpickle. | ||||
| class PyRRef { | ||||
|  public: | ||||
|   explicit PyRRef(std::shared_ptr<RRef> rref); | ||||
|   // creates a local RRef with the given object as value | ||||
|   explicit PyRRef(const py::object& value); | ||||
|   explicit PyRRef(c10::intrusive_ptr<RRef> rref); | ||||
|  | ||||
|   bool isOwner() const; | ||||
|   WorkerInfo owner() const; | ||||
| @ -25,7 +24,7 @@ class PyRRef { | ||||
|   static PyRRef unpickle(const py::tuple& t); | ||||
|  | ||||
|  private: | ||||
|   std::shared_ptr<RRef> rref_; | ||||
|   c10::intrusive_ptr<RRef> rref_; | ||||
| }; | ||||
|  | ||||
| } // namespace rpc | ||||
|  | ||||
| @ -156,7 +156,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc( | ||||
|     case MessageType::SCRIPT_RREF_FETCH_CALL: { | ||||
|       auto& srf = static_cast<ScriptRRefFetchCall&>(rpc); | ||||
|       auto& ctx = RRefContext::getInstance(); | ||||
|       std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId()); | ||||
|       c10::intrusive_ptr<OwnerRRef> rref = | ||||
|           ctx.getOwnerRRef(srf.rrefId()); | ||||
|       if (rref->hasValue()) { // optional fast-path | ||||
|         return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage()); | ||||
|       } | ||||
| @ -181,7 +182,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc( | ||||
|     case MessageType::PYTHON_RREF_FETCH_CALL: { | ||||
|       auto& prf = static_cast<PythonRRefFetchCall&>(rpc); | ||||
|       auto& ctx = RRefContext::getInstance(); | ||||
|       std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(prf.rrefId()); | ||||
|       c10::intrusive_ptr<OwnerRRef> rref = | ||||
|           ctx.getOwnerRRef(prf.rrefId()); | ||||
|       if (rref->hasValue()) { // optional fast-path | ||||
|         auto value = rref->getValue(); | ||||
|         py::object pyValue; | ||||
|  | ||||
| @ -28,7 +28,7 @@ RRefContext& RRefContext::getInstance() { | ||||
|   return *context; | ||||
| } | ||||
|  | ||||
| std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance( | ||||
| std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance( | ||||
|     bool ignoreRRefLeak) { | ||||
|   auto& ctx = RRefContext::getInstance(); | ||||
|   { | ||||
| @ -36,7 +36,7 @@ std::vector<std::shared_ptr<RRef>> RRefContext::destroyInstance( | ||||
|     ctx.destroyed_ = true; | ||||
|   } | ||||
|   ctx.checkRRefLeaks(ignoreRRefLeak); | ||||
|   std::vector<std::shared_ptr<RRef>> deletedRRefs; | ||||
|   std::vector<c10::intrusive_ptr<RRef>> deletedRRefs; | ||||
|   for (auto& entry : ctx.owners_) { | ||||
|     auto rref = entry.second; | ||||
|     if (rref->isPyObj()) { | ||||
| @ -105,9 +105,7 @@ void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::shared_ptr<UserRRef> RRefContext::createUserRRef( | ||||
|     worker_id_t ownerId, | ||||
|     const TypePtr& type) { | ||||
| c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(worker_id_t ownerId, const TypePtr& type) { | ||||
|   TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner."); | ||||
|   // Explicitly creating rrefId before forkId to make sure the order is | ||||
|   // deterministic, as the argument evaluation order is system and compiler | ||||
| @ -117,7 +115,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef( | ||||
|   return createUserRRef(ownerId, rrefId, forkId, type); | ||||
| } | ||||
|  | ||||
| std::shared_ptr<UserRRef> RRefContext::createUserRRef( | ||||
| c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef( | ||||
|     worker_id_t ownerId, | ||||
|     const RRefId& rrefId, | ||||
|     const ForkId& forkId, | ||||
| @ -136,7 +134,7 @@ std::shared_ptr<UserRRef> RRefContext::createUserRRef( | ||||
|   // The reason for not adding the pending user here is to put addPendingUser() | ||||
|   // close to where the RPC occurs, and it is more clear to pair it with | ||||
|   // deletePendingUser() in the response callback at the call site. | ||||
|   return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId, type)); | ||||
|   return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type); | ||||
| } | ||||
|  | ||||
| void RRefContext::delUser( | ||||
| @ -156,7 +154,7 @@ void RRefContext::delUser( | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::shared_ptr<RRef> RRefContext::getOrCreateRRef( | ||||
| c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef( | ||||
|     const RRefForkData& rrefForkData, | ||||
|     const TypePtr& type) { | ||||
|   auto& ownerId = rrefForkData.ownerId_; | ||||
| @ -171,7 +169,7 @@ std::shared_ptr<RRef> RRefContext::getOrCreateRRef( | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef( | ||||
| c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef( | ||||
|     const RRefId& rrefId, | ||||
|     const TypePtr& type) { | ||||
|   std::lock_guard<std::mutex> lock(mutex_); | ||||
| @ -182,41 +180,40 @@ std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef( | ||||
|     // NB: cannot use make_shared here as the constructor of OwnerRRef is | ||||
|     // private. | ||||
|     auto rref = | ||||
|         std::shared_ptr<OwnerRRef>(new OwnerRRef(getWorkerId(), rrefId, type)); | ||||
|         c10::make_intrusive<OwnerRRef>(getWorkerId(), rrefId, type); | ||||
|     owners_[rref->rrefId()] = rref; | ||||
|     ownerCV_.notify_all(); | ||||
|     return rref; | ||||
|   } else { | ||||
|     // Scenario (2) retrieving an existing RRef | ||||
|     auto ownerRRef = std::static_pointer_cast<OwnerRRef>(iter->second); | ||||
|     auto ownerRRef = c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second); | ||||
|     TORCH_INTERNAL_ASSERT(ownerRRef->type() == type); | ||||
|     return ownerRRef; | ||||
|   } | ||||
| } | ||||
|  | ||||
| std::shared_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) { | ||||
| c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) { | ||||
|   // Don't add this OnwerRRef to the owners_ map yet, otherwise | ||||
|   // it will never be removed from there. Instead, only add it to the | ||||
|   // map in prepareChildFork, in case this local RRef is being passed | ||||
|   // to another worker. | ||||
|   return std::shared_ptr<OwnerRRef>( | ||||
|       new OwnerRRef(getWorkerId(), genGloballyUniqueId(), type)); | ||||
|   return c10::make_intrusive<OwnerRRef>(getWorkerId(), genGloballyUniqueId(), type); | ||||
| } | ||||
|  | ||||
| std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) { | ||||
| c10::intrusive_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) { | ||||
|   std::unique_lock<std::mutex> lock(mutex_); | ||||
|   const auto iter = owners_.find(rrefId); | ||||
|   if (iter == owners_.end()) { | ||||
|     // Scenario (1) RRef is used before it is created | ||||
|     ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); }); | ||||
|     return std::static_pointer_cast<OwnerRRef>(owners_[rrefId]); | ||||
|     return c10::static_intrusive_pointer_cast<OwnerRRef>(owners_[rrefId]); | ||||
|   } else { | ||||
|     // Scenario (2) retrieving an existing RRef | ||||
|     return std::static_pointer_cast<OwnerRRef>(iter->second); | ||||
|     return c10::static_intrusive_pointer_cast<OwnerRRef>(iter->second); | ||||
|   } | ||||
| } | ||||
|  | ||||
| RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) { | ||||
| RRefForkData RRefContext::prepareChildFork(const c10::intrusive_ptr<RRef>& rref) { | ||||
|   auto rrefForkData = rref->fork(); | ||||
|   if (rref->isOwner()) { | ||||
|     // Note [Early Fork Registration] | ||||
| @ -256,7 +253,7 @@ RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) { | ||||
| void RRefContext::notifyOwnerAndParentOfFork( | ||||
|     const ForkId& forkId, | ||||
|     worker_id_t parent, | ||||
|     const std::shared_ptr<RRef>& rref) { | ||||
|     const c10::intrusive_ptr<RRef>& rref) { | ||||
|   if (parent == rref->owner()) { | ||||
|     if (parent == agent_->getWorkerInfo().id_) { | ||||
|       // Owner sending RRef to self, remove the forkId as it was added during | ||||
| @ -310,7 +307,7 @@ void RRefContext::notifyOwnerAndParentOfFork( | ||||
|  | ||||
| void RRefContext::addPendingChild( | ||||
|     const ForkId& forkId, | ||||
|     const std::shared_ptr<RRef>& rref) { | ||||
|     const c10::intrusive_ptr<RRef>& rref) { | ||||
|   // see Note [Early Fork Registration] | ||||
|   // If the parent is the owner, it should directly add the child UserRRef as a | ||||
|   // fork. | ||||
| @ -334,7 +331,9 @@ void RRefContext::delPendingChild(const ForkId& forkId) { | ||||
|  | ||||
| void RRefContext::addPendingUser( | ||||
|     const ForkId& forkId, | ||||
|     const std::shared_ptr<RRef>& rref) { | ||||
|     const c10::intrusive_ptr<RRef>& rref) { | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User."); | ||||
|   std::lock_guard<std::mutex> lock(mutex_); | ||||
|   TORCH_INTERNAL_ASSERT( | ||||
|       pendingUsers_.find(forkId) == pendingUsers_.end(), | ||||
| @ -362,7 +361,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) { | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef>& rref) { | ||||
| void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) { | ||||
|   std::lock_guard<std::mutex> lock(mutex_); | ||||
|   const auto& rrefId = rref->rrefId(); | ||||
|   owners_[rrefId] = rref; | ||||
| @ -384,10 +383,10 @@ void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) { | ||||
|   rrefForks.insert(forkId); | ||||
| } | ||||
|  | ||||
| std::shared_ptr<RRef> RRefContext::delForkOfOwner( | ||||
| c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner( | ||||
|     const RRefId& rrefId, | ||||
|     const ForkId& forkId) { | ||||
|   std::shared_ptr<RRef> deletedRRef = nullptr; | ||||
|   c10::intrusive_ptr<RRef> deletedRRef; | ||||
|   { | ||||
|     std::lock_guard<std::mutex> lock(mutex_); | ||||
|     auto rrefIter = forks_.find(rrefId); | ||||
|  | ||||
| @ -28,7 +28,7 @@ class TORCH_API RRefContext { | ||||
|   // hold py::object. The call-site is also responsible for resetting those | ||||
|   // shared_ptr objects with a GIL. See comments at delForkOfOwner() for more | ||||
|   // details. | ||||
|   static std::vector<std::shared_ptr<RRef>> destroyInstance( | ||||
|   static std::vector<c10::intrusive_ptr<RRef>> destroyInstance( | ||||
|       bool ignoreRRefLeak = true); | ||||
|  | ||||
|   static void handleException(const c10::optional<utils::FutureError>& futErr); | ||||
| @ -60,27 +60,21 @@ class TORCH_API RRefContext { | ||||
|   } | ||||
|  | ||||
|   // create a ``UserRRef`` owned by the worker ``ownerId`` | ||||
|   std::shared_ptr<UserRRef> createUserRRef( | ||||
|       worker_id_t ownerId, | ||||
|       const TypePtr& type); | ||||
|   c10::intrusive_ptr<UserRRef> createUserRRef(worker_id_t ownerId, const TypePtr& type); | ||||
|  | ||||
|   // Convert an RRefForkData into an RRef. This RRef could be user or owner. | ||||
|   // This RRef could have already existed before, or could be created in this | ||||
|   // method, we pass type here to validate or help the rref creation. | ||||
|   std::shared_ptr<RRef> getOrCreateRRef( | ||||
|       const RRefForkData& rfd, | ||||
|       const TypePtr& type); | ||||
|   c10::intrusive_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd, const TypePtr& type); | ||||
|  | ||||
|   // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new | ||||
|   // one. | ||||
|   std::shared_ptr<OwnerRRef> getOrCreateOwnerRRef( | ||||
|       const RRefId& rrefId, | ||||
|       const TypePtr& type); | ||||
|   c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef(const RRefId& rrefId, const TypePtr& type); | ||||
|  | ||||
|   // Create an empty owner rref of type. | ||||
|   std::shared_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type); | ||||
|   c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type); | ||||
|  | ||||
|   std::shared_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId); | ||||
|   c10::intrusive_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId); | ||||
|  | ||||
|   // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when | ||||
|   // making a remote call to self, which as for now, still goes through serde | ||||
| @ -92,9 +86,9 @@ class TORCH_API RRefContext { | ||||
|   // and this could happen before the self remote call finishes. To prevent | ||||
|   // that, this API adds the RRefId as a ForkId, which will then delete the | ||||
|   // ForkId when the self remote is done. | ||||
|   void addSelfAsFork(std::shared_ptr<OwnerRRef>& rref); | ||||
|   void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref); | ||||
|  | ||||
|   // Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the | ||||
|   // Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the | ||||
|   // ``OwnerRRef`` in a map to keep it alive. | ||||
|   void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId); | ||||
|   // Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the | ||||
| @ -106,19 +100,19 @@ class TORCH_API RRefContext { | ||||
|   // py::object, deleting it require GIL. The call site should guarded it with | ||||
|   // a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally | ||||
|   // left out of this function to avoid creating dependency on pybind. | ||||
|   std::shared_ptr<RRef> delForkOfOwner( | ||||
|   c10::intrusive_ptr<RRef> delForkOfOwner( | ||||
|       const RRefId& rrefId, | ||||
|       const ForkId& forkId); | ||||
|  | ||||
|   // Invoked when pickling an RRef to setup child/fork properly | ||||
|   RRefForkData prepareChildFork(const std::shared_ptr<RRef>& rref); | ||||
|   RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref); | ||||
|   // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and | ||||
|   // send RREF_CHILD_ACCEPT to the parent. | ||||
|   // NB: forkId is necessary here as the rref could be an OwnerRRef | ||||
|   void notifyOwnerAndParentOfFork( | ||||
|       const ForkId& forkId, | ||||
|       worker_id_t parent, | ||||
|       const std::shared_ptr<RRef>& rref); | ||||
|       const c10::intrusive_ptr<RRef>& rref); | ||||
|  | ||||
|   // When a UserRRef is forked to another worker (user or owner), it is added | ||||
|   // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT | ||||
| @ -128,12 +122,12 @@ class TORCH_API RRefContext { | ||||
|   // previously submitted rpc/remote calls are acked before sending out the | ||||
|   // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too | ||||
|   // soon. | ||||
|   void addPendingChild(const ForkId& forkId, const std::shared_ptr<RRef>& rref); | ||||
|   void addPendingChild(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref); | ||||
|   void delPendingChild(const ForkId& forkId); | ||||
|  | ||||
|   // When a UserRRef is created, it is added into pendingUsers_ to be held alive | ||||
|   // until it receives RREF_USER_ACCEPT from the owner. | ||||
|   void addPendingUser(const ForkId& forkId, const std::shared_ptr<RRef>& rref); | ||||
|   void addPendingUser(const ForkId& forkId, const c10::intrusive_ptr<RRef>& rref); | ||||
|   void delPendingUser(const ForkId& forkId); | ||||
|  | ||||
|   void delUser( | ||||
| @ -146,7 +140,7 @@ class TORCH_API RRefContext { | ||||
|  private: | ||||
|   RRefContext(std::shared_ptr<RpcAgent>); | ||||
|  | ||||
|   std::shared_ptr<UserRRef> createUserRRef( | ||||
|   c10::intrusive_ptr<UserRRef> createUserRRef( | ||||
|       worker_id_t ownerId, | ||||
|       const RRefId& rrefId, | ||||
|       const ForkId& forkId, | ||||
| @ -162,7 +156,7 @@ class TORCH_API RRefContext { | ||||
|   const std::shared_ptr<RpcAgent> agent_; | ||||
|   mutable std::mutex mutex_; | ||||
|   // Keep OwnerRRefs alive until there is no living UserRRefs. | ||||
|   std::unordered_map<RRefId, std::shared_ptr<RRef>, RRefId::Hash> owners_; | ||||
|   std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_; | ||||
|   // A conditional variable to block getOwnerRRef() calls until the | ||||
|   // corresponding OwnerRRef has been created and inserted into the owners_ map. | ||||
|   // The method getOwnerRRef() is triggered by rref.to_here() messages. The | ||||
| @ -184,7 +178,7 @@ class TORCH_API RRefContext { | ||||
|       RRefId::Hash> | ||||
|       forks_; | ||||
|  | ||||
|   // The follow two maps keep UserRRefs alive by holding a shared_ptr to the | ||||
|   // The follow two maps keep UserRRefs alive by holding a intrusive_ptr to the | ||||
|   // RRef instances. A UserRRef must be added into this map if any of the | ||||
|   // following two conditions is true: | ||||
|   // | ||||
| @ -193,7 +187,7 @@ class TORCH_API RRefContext { | ||||
|   //     It can be used or shared, but cannot be deleted, and hence kept alive | ||||
|   //     in this map. A message of type RREF_USER_ACCEPT will remove the | ||||
|   //     corresponding RRef from this map. | ||||
|   std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash> pendingUsers_; | ||||
|   std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> pendingUsers_; | ||||
|  | ||||
|   // (2) A UserRRef has forked a child UserRRef which has not been accepted by | ||||
|   //     the owner yet. | ||||
| @ -201,7 +195,7 @@ class TORCH_API RRefContext { | ||||
|   //     In this case, this UserRRef cannot send out RREF_USER_DELETE message, | ||||
|   //     as it could potentially trigger the OwnerRRef been deleted before the | ||||
|   //     owner learns about the forked child. | ||||
|   std::unordered_map<ForkId, std::shared_ptr<RRef>, ForkId::Hash> | ||||
|   std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> | ||||
|       pendingChildren_; | ||||
|  | ||||
|   std::mutex destroyedMutex_; | ||||
|  | ||||
| @ -77,7 +77,7 @@ const ForkId& UserRRef::forkId() const { | ||||
|   return forkId_; | ||||
| } | ||||
|  | ||||
| std::vector<IValue> UserRRef::toHere() { | ||||
| IValue UserRRef::toHere() { | ||||
|   auto agent = RpcAgent::getCurrentRpcAgent(); | ||||
|  | ||||
|   // ScriptRRefFetchCall message always carries autograd context id even if | ||||
| @ -107,7 +107,13 @@ std::vector<IValue> UserRRef::toHere() { | ||||
|       "or PYTHON_RREF_FETCH_RET"); | ||||
|   RpcCommandBase& rpc = *response; | ||||
|   auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc); | ||||
|   return rrefFetchRet.values(); | ||||
|   if (isPyObj()) { | ||||
|     // wrap python serialized vector of ivalues into tuple, this | ||||
|     // made the C++ toHere interface to return single IValue | ||||
|     return ivalue::Tuple::create(rrefFetchRet.values()); | ||||
|   } else { | ||||
|     return rrefFetchRet.values().front(); | ||||
|   } | ||||
| } | ||||
|  | ||||
| //////////////////////////  OwnerRRef  ///////////////////////////////////// | ||||
|  | ||||
| @ -1,10 +1,10 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/core/jit_type.h> | ||||
| #include <ATen/core/rref_interface.h> | ||||
| #include <c10/util/Optional.h> | ||||
| #include <torch/csrc/distributed/rpc/message.h> | ||||
| #include <torch/csrc/distributed/rpc/rpc_agent.h> | ||||
| #include <torch/csrc/distributed/rpc/rref_interface.h> | ||||
| #include <torch/csrc/distributed/rpc/types.h> | ||||
|  | ||||
| #include <atomic> | ||||
| @ -27,12 +27,13 @@ struct TORCH_API RRefForkData { | ||||
|  | ||||
|   RRefForkData( | ||||
|       worker_id_t ownerId, | ||||
|       const RRefId& rrefId_, | ||||
|       const ForkId& forkId_, | ||||
|       const RRefId& rrefId, | ||||
|       const ForkId& forkId, | ||||
|       worker_id_t parent, | ||||
|       std::string typeStr); | ||||
| }; | ||||
|  | ||||
|  | ||||
| // Note [RRef Protocol] | ||||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| // | ||||
| @ -198,7 +199,7 @@ class TORCH_API RRef : public RRefInterface { | ||||
|   inline bool isPyObj() { | ||||
|     return type_ == PyObjectType::get(); | ||||
|   } | ||||
|   inline const TypePtr type() { | ||||
|   inline const TypePtr type() const override{ | ||||
|     return type_; | ||||
|   } | ||||
|  | ||||
| @ -228,6 +229,8 @@ class TORCH_API UserRRef final : public RRef { | ||||
|   UserRRef& operator=(const UserRRef& other) = delete; | ||||
|   UserRRef& operator=(UserRRef&& other) = delete; | ||||
|  | ||||
|   UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId, TypePtr type); | ||||
|  | ||||
|   inline bool isOwner() const override { | ||||
|     return false; | ||||
|   } | ||||
| @ -237,7 +240,7 @@ class TORCH_API UserRRef final : public RRef { | ||||
|  | ||||
|   // Get of copy of the value from the ``OwnerRRef``. If the value is not ready | ||||
|   // yet, this call will block. | ||||
|   std::vector<IValue> toHere(); | ||||
|   IValue toHere(); | ||||
|  | ||||
|   // Upon destruction, this ``UserRRef`` will tell the owner to deref. | ||||
|   ~UserRRef() override; | ||||
| @ -245,12 +248,6 @@ class TORCH_API UserRRef final : public RRef { | ||||
|  private: | ||||
|   friend class RRefContext; | ||||
|  | ||||
|   UserRRef( | ||||
|       worker_id_t ownerId, | ||||
|       const RRefId& rrefId, | ||||
|       const ForkId& forkId, | ||||
|       TypePtr type); | ||||
|  | ||||
|   const ForkId forkId_; | ||||
| }; | ||||
|  | ||||
| @ -263,6 +260,15 @@ class TORCH_API OwnerRRef final : public RRef { | ||||
|   OwnerRRef& operator=(const OwnerRRef& other) = delete; | ||||
|   OwnerRRef& operator=(OwnerRRef&& other) = delete; | ||||
|  | ||||
|   OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type) | ||||
|       : OwnerRRef(ownerId, rrefId, type, {}) {} | ||||
|  | ||||
|   OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type, c10::optional<IValue> value) | ||||
|       : RRef(ownerId, rrefId, std::move(type)) { | ||||
|     value_ = std::move(value); | ||||
|   } | ||||
|  | ||||
|  | ||||
|   inline bool isOwner() const override { | ||||
|     return true; | ||||
|   } | ||||
| @ -284,18 +290,6 @@ class TORCH_API OwnerRRef final : public RRef { | ||||
|  private: | ||||
|   friend class RRefContext; | ||||
|  | ||||
|   OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type) | ||||
|       : OwnerRRef(ownerId, rrefId, type, {}) {} | ||||
|  | ||||
|   OwnerRRef( | ||||
|       worker_id_t ownerId, | ||||
|       const RRefId& rrefId, | ||||
|       TypePtr type, | ||||
|       c10::optional<IValue> value) | ||||
|       : RRef(ownerId, rrefId, std::move(type)) { | ||||
|     value_ = std::move(value); | ||||
|   } | ||||
|  | ||||
|   c10::optional<IValue> value_; | ||||
|   mutable std::mutex mutex_; | ||||
|   mutable std::condition_variable valueCV_; | ||||
|  | ||||
| @ -49,7 +49,7 @@ c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscript( | ||||
|   return futPtr; | ||||
| } | ||||
|  | ||||
| std::shared_ptr<UserRRef> remoteTorchscript( | ||||
| c10::intrusive_ptr<UserRRef> remoteTorchscript( | ||||
|     const std::string& dstWorkerName, | ||||
|     const c10::QualifiedName& qualifiedName, | ||||
|     const c10::FunctionSchema& functionSchema, | ||||
|  | ||||
| @ -25,7 +25,7 @@ c10::intrusive_ptr<c10::ivalue::Future> TORCH_API rpcTorchscript( | ||||
|     const c10::FunctionSchema& functionSchema, | ||||
|     std::vector<c10::IValue>& stack); | ||||
|  | ||||
| std::shared_ptr<UserRRef> TORCH_API remoteTorchscript( | ||||
| c10::intrusive_ptr<UserRRef> TORCH_API remoteTorchscript( | ||||
|     const std::string& dstWorkerName, | ||||
|     const c10::QualifiedName& qualifiedName, | ||||
|     const c10::FunctionSchema& functionSchema, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user