#include #include #include #include #include #include namespace torch { namespace distributed { namespace rpc { ///////////////////// Pickle/Unpickle Helplers //////////////////////////// namespace { py::tuple toPyTuple(const RRefForkData& rrefForkData) { // add GIL as it is contructing a py::object pybind11::gil_scoped_acquire ag; return py::make_tuple( rrefForkData.ownerId_, rrefForkData.rrefId_.createdOn_, rrefForkData.rrefId_.localId_, rrefForkData.forkId_.createdOn_, rrefForkData.forkId_.localId_, rrefForkData.parent_, rrefForkData.typeStr_); } RRefForkData fromPyTuple(const py::tuple& pyTuple) { // add GIL as it is accessing a py::object pybind11::gil_scoped_acquire ag; TORCH_INTERNAL_ASSERT( pyTuple.size() == RFD_TUPLE_SIZE, "Pickled RRefForkData must contain ", RFD_TUPLE_SIZE, " numbers."); worker_id_t ownerId = pyTuple[OWNER_IDX].cast(); // const reference will extend the lifetime of the temporary variable const RRefId& rrefId = RRefId( pyTuple[RREFID_ON_IDX].cast(), pyTuple[RREFID_ID_IDX].cast()); const RRefId& forkId = RRefId( pyTuple[FORKID_ON_IDX].cast(), pyTuple[FORKID_ID_IDX].cast()); worker_id_t parent = pyTuple[PARENT_IDX].cast(); const std::string& typeStr = pyTuple[TYPE_IDX].cast(); return RRefForkData(ownerId, rrefId, forkId, parent, typeStr); } TypePtr tryInferTypeWithTypeHint( const py::object& value, const py::object& type_hint) { // If the py::object to be contained by the RRef is a ScripModule, we enforce // users to specify its ModuleInterface type. if (auto module = jit::as_module(value)) { TORCH_CHECK( !type_hint.is_none(), "The RRef being created contains a ScriptModule, " "must provide its ModuleInterface type hint. "); c10::QualifiedName type_qualified_name = c10::QualifiedName( py::cast(py::module::import("torch.jit") .attr("_qualified_name")(type_hint))); TypePtr type_hint_ptr = jit::get_python_cu()->get_interface(type_qualified_name); std::ostringstream subtype_check_msg; TORCH_CHECK( type_hint_ptr != nullptr && module.value().type()->isSubtypeOfExt( type_hint_ptr, &subtype_check_msg), module.value().type()->python_str(), " is not a subtype of the type hint: ", type_qualified_name.qualifiedName(), ", did you pass a valid interface type?\n", subtype_check_msg.str()); return type_hint_ptr; } else { TORCH_CHECK( type_hint.is_none(), "type_hint should only be specified when the RRef being created contains a ScriptModule."); } // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but // excluding ScripModule. jit::InferredType type_inferred = jit::tryToInferType(value); if (type_inferred.success()) { // If we could infer the type from the pyobject, we create // the RRef with the IValue of that type. return type_inferred.type(); } // Otherwise it's a pure pyobject, create the RRef // that holds an IValue of an pyobject. return PyObjectType::get(); } } // namespace /////////////////////////// PyRRef ////////////////////////////////// PyRRef::PyRRef(c10::intrusive_ptr rref) : rref_(std::move(rref)) { TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); } PyRRef::PyRRef(const py::object& value, const py::object& type_hint) : PyRRef([&value, &type_hint]() { TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint); auto rref = RRefContext::getInstance().createOwnerRRef(elem_type); // jit::toIValue takes a py::handle as the first argument, and it calls // py::handle.cast() to incref of provided value. The // returned ivalue will keep the reference alive. // NB: the first argument const py::object& value must be kept alive // until the following jit::toIValue returns (i.e., incref done). That's // why this ctor can only be called while holding GIL. IValue ivalue = jit::toIValue(value, elem_type); rref->setValue(std::move(ivalue)); return rref; }()) {} c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. return wrapFutureMessageInJitFuture( rref_->getOwnerCreationFuture(), false /* hasValue */); } bool PyRRef::isOwner() const { return rref_->isOwner(); } bool PyRRef::confirmedByOwner() const { return rref_->confirmedByOwner(); } WorkerInfo PyRRef::owner() const { return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner()); } std::string PyRRef::ownerName() const { return rref_->ownerName(); } py::object PyRRef::toHere() { if (rref_->isOwner()) { return localValue(); } else { // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds // a python object IValue value = c10::static_intrusive_pointer_cast(rref_)->toHere(); if (rref_->isPyObj()) { // python_rpc_handler deserialization will acquires GIL. auto rfr_values = value.toTuple()->elements(); auto& pythonRpcHandler = PythonRpcHandler::getInstance(); auto ret = pythonRpcHandler.deserialize( SerializedPyObj::fromIValues(rfr_values)); pythonRpcHandler.handleException(ret); return ret; } else { // acquiring GIL as torch::jit::toPyObject creates new py::object // without grabbing the GIL. pybind11::gil_scoped_acquire ag; return torch::jit::toPyObject(std::move(value)); } } } py::object PyRRef::localValue() { TORCH_CHECK( rref_->isOwner(), "Cannot call localValue() on a non-local reference. Call it on ", owner().name_); py::object res; auto value = c10::static_intrusive_pointer_cast(rref_)->getValue(); auto& rpcHandler = PythonRpcHandler::getInstance(); { // acquiring GIL as torch::jit::toPyObject creates new py::object without // grabbing the GIL. pybind11::gil_scoped_acquire ag; res = torch::jit::toPyObject(std::move(value)); rpcHandler.handleExceptionGILHeld(res); } return res; } std::string PyRRef::str() const { if (rref_->isOwner()) { return c10::str("OwnerRRef(", rref_->rrefId(), ")"); } else { return c10::str( "UserRRef(RRefId = ", rref_->rrefId(), ", ForkId = ", c10::static_intrusive_pointer_cast(rref_)->forkId(), ")"); } } py::object PyRRef::createRRefProxy(const RRefProxyType& type) const { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); pybind11::gil_scoped_acquire ag; auto& functions = pythonRpcHandler.getRRefProxyFunctions(); auto& ctor = functions.rrefProxyCtor_; switch (type) { case RRefProxyType::RPC_SYNC: { return ctor(*this, functions.rpcSync_); } case RRefProxyType::RPC_ASYNC: { return ctor(*this, functions.rpcAsync_); } case RRefProxyType::REMOTE: { return ctor(*this, functions.remote_); } default: { TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type); } } } py::tuple PyRRef::pickle() const { auto& ctx = RRefContext::getInstance(); auto rrefForkData = ctx.prepareChildFork(rref_); return toPyTuple(rrefForkData); } PyRRef PyRRef::unpickle(const py::tuple& pyTuple) { auto& ctx = RRefContext::getInstance(); auto rrefForkData = fromPyTuple(pyTuple); TypePtr rrefType = PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_); c10::intrusive_ptr rref = ctx.getOrCreateRRef(rrefForkData, rrefType); ctx.notifyOwnerAndParentOfFork( rrefForkData.forkId_, rrefForkData.parent_, rref); return PyRRef(std::move(rref)); } c10::IValue PyRRef::toIValue() { // cast to RRefInterface to hold it into IValue auto rrefPtr = c10::static_intrusive_pointer_cast(rref_); return IValue(rrefPtr); } } // namespace rpc } // namespace distributed } // namespace torch