#include #include #include #include #include #include #include #include #include namespace torch { namespace distributed { namespace rpc { ///////////////////// Pickle/Unpickle Helplers //////////////////////////// namespace { py::tuple toPyTuple(const RRefForkData& rrefForkData) { // add GIL as it is contructing a py::object pybind11::gil_scoped_acquire ag; return py::make_tuple( rrefForkData.ownerId_, rrefForkData.rrefId_.createdOn_, rrefForkData.rrefId_.localId_, rrefForkData.forkId_.createdOn_, rrefForkData.forkId_.localId_, rrefForkData.parent_, rrefForkData.typeStr_); } RRefForkData fromPyTuple(const py::tuple& pyTuple) { // add GIL as it is accessing a py::object pybind11::gil_scoped_acquire ag; TORCH_INTERNAL_ASSERT( pyTuple.size() == RFD_TUPLE_SIZE, "Pickled RRefForkData must contain ", RFD_TUPLE_SIZE, " numbers."); worker_id_t ownerId = pyTuple[OWNER_IDX].cast(); // const reference will extend the lifetime of the temporary variable const RRefId& rrefId = RRefId( pyTuple[RREFID_ON_IDX].cast(), pyTuple[RREFID_ID_IDX].cast()); const RRefId& forkId = RRefId( pyTuple[FORKID_ON_IDX].cast(), pyTuple[FORKID_ID_IDX].cast()); worker_id_t parent = pyTuple[PARENT_IDX].cast(); const std::string& typeStr = pyTuple[TYPE_IDX].cast(); return RRefForkData(ownerId, rrefId, forkId, parent, typeStr); } TypePtr tryInferTypeWithTypeHint( const py::object& value, const py::object& type_hint) { // If the py::object to be contained by the RRef is a ScripModule, we enforce // users to specify its ModuleInterface type. if (auto module = jit::as_module(value)) { TORCH_CHECK( !type_hint.is_none(), "The RRef being created contains a ScriptModule, " "must provide its ModuleInterface type hint. "); c10::QualifiedName type_qualified_name = c10::QualifiedName( py::cast(py::module::import("torch._jit_internal") .attr("_qualified_name")(type_hint))); TypePtr type_hint_ptr = jit::get_python_cu()->get_interface(type_qualified_name); std::ostringstream subtype_check_msg; TORCH_CHECK( type_hint_ptr != nullptr && module.value().type()->isSubtypeOfExt( *type_hint_ptr, &subtype_check_msg), module.value().type()->repr_str(), " is not a subtype of the type hint: ", type_qualified_name.qualifiedName(), ", did you pass a valid interface type?\n", subtype_check_msg.str()); return type_hint_ptr; } else { TORCH_CHECK( type_hint.is_none(), "type_hint should only be specified when the RRef being created contains a ScriptModule."); } // Check if value is an instance of a ScriptClass. If not, skip type inference // because it will try to script the class that value is in instance of, and // this should be avoided. py::bool_ can_compile = py::module::import("torch._jit_internal") .attr("can_compile_class")(value.get_type()); if (py::cast(can_compile)) { py::object existing_ty = py::module::import("torch.jit._state") .attr("_get_script_class")(value.get_type()); if (existing_ty.is_none()) { return PyObjectType::get(); } } // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but // excluding ScripModule. jit::InferredType type_inferred = jit::tryToInferType(value); if (type_inferred.success()) { // If we could infer the type from the pyobject, we create // the RRef with the IValue of that type. return type_inferred.type(); } // Otherwise it's a pure pyobject, create the RRef // that holds an IValue of an pyobject. return PyObjectType::get(); } } // namespace /////////////////////////// PyRRef ////////////////////////////////// PyRRef::PyRRef(c10::intrusive_ptr rref) : rref_(std::move(rref)), profilingFuture_(c10::nullopt) { TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); } PyRRef::PyRRef(const py::object& value, const py::object& type_hint) : PyRRef([&value, &type_hint]() mutable { TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint); auto rref = RRefContext::getInstance().createOwnerRRef(elem_type); // jit::toIValue takes a py::handle as the first argument, and it calls // py::handle.cast() to incref of provided value. The // returned ivalue will keep the reference alive. // NB: the first argument const py::object& value must be kept alive // until the following jit::toIValue returns (i.e., incref done). That's // why this ctor can only be called while holding GIL. IValue ivalue = jit::toIValue(value, elem_type); rref->setValue(std::move(ivalue)); return rref; }()) {} PyRRef::~PyRRef() { if (type_.has_value()) { (*type_).dec_ref(); // explicitly setting PyObject* to nullptr to prevent py::object's dtor to // decref on the PyObject again. // See Note [Destructing py::object] in python_ivalue.h (*type_).ptr() = nullptr; } } c10::intrusive_ptr PyRRef::getFuture() const { // Marking hasValue to false, as this Future is only used for signaling // profiler to update profiling result and the profiler does not retrieve // any value from it. return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */); } c10::intrusive_ptr PyRRef::getProfilingFuture() const { TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!"); return *profilingFuture_; } void PyRRef::setProfilingFuture(c10::intrusive_ptr profilingFuture) { profilingFuture_ = std::move(profilingFuture); } bool PyRRef::isOwner() const { return rref_->isOwner(); } bool PyRRef::confirmedByOwner() const { return rref_->confirmedByOwner(); } WorkerInfo PyRRef::owner() const { return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner()); } std::string PyRRef::ownerName() const { return rref_->ownerName(); } py::object PyRRef::toHere(const float timeoutSeconds) const { if (rref_->isOwner()) { return localValue(); } else { // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds // a python object IValue value = c10::static_intrusive_pointer_cast(rref_)->toHere( timeoutSeconds); if (rref_->isPyObj()) { // python_rpc_handler deserialization will acquires GIL. auto rfr_values = value.toTuple()->elements(); auto& pythonRpcHandler = PythonRpcHandler::getInstance(); auto ret = pythonRpcHandler.deserialize( SerializedPyObj::fromIValues(rfr_values)); pythonRpcHandler.handleException(ret); return ret; } else { // acquiring GIL as torch::jit::toPyObject creates new py::object // without grabbing the GIL. pybind11::gil_scoped_acquire ag; return torch::jit::toPyObject(std::move(value)); } } } py::object PyRRef::localValue() const { TORCH_CHECK( rref_->isOwner(), "For ", *rref_, ", can't call localValue() on user ", RRefContext::getInstance().agent()->getWorkerInfo(), ". Call it on owner ", owner()); py::object res; auto value = c10::static_intrusive_pointer_cast(rref_)->getValue(); auto& rpcHandler = PythonRpcHandler::getInstance(); { // acquiring GIL as torch::jit::toPyObject creates new py::object without // grabbing the GIL. pybind11::gil_scoped_acquire ag; res = torch::jit::toPyObject(std::move(value)); rpcHandler.handleExceptionGILHeld(res); } return res; } std::string PyRRef::str() const { if (rref_->isOwner()) { return c10::str("OwnerRRef(", rref_->rrefId(), ")"); } else { return c10::str( "UserRRef(RRefId = ", rref_->rrefId(), ", ForkId = ", c10::static_intrusive_pointer_cast(rref_)->forkId(), ")"); } } py::object PyRRef::createRRefProxy( const RRefProxyType& type, float timeoutSeconds) const { auto& pythonRpcHandler = PythonRpcHandler::getInstance(); pybind11::gil_scoped_acquire ag; auto& functions = pythonRpcHandler.getRRefProxyFunctions(); auto& ctor = functions.rrefProxyCtor_; switch (type) { case RRefProxyType::RPC_SYNC: { return ctor(*this, functions.rpcSync_, timeoutSeconds); } case RRefProxyType::RPC_ASYNC: { return ctor(*this, functions.rpcAsync_, timeoutSeconds); } case RRefProxyType::REMOTE: { return ctor(*this, functions.remote_, timeoutSeconds); } default: { TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type); } } } py::object PyRRef::getRRefType(float timeout, bool blocking) { // GIL is not released when calling this function. if (!type_.has_value()) { pybind11::gil_scoped_release release; auto& pythonRpcHandler = PythonRpcHandler::getInstance(); auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions(); pybind11::gil_scoped_acquire acquire; type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking) : typeFuncs.onUser_(*this, timeout, blocking); } // Returns py::object that can be Python type or future. return *type_; } py::tuple PyRRef::pickle() const { auto& ctx = RRefContext::getInstance(); auto rrefForkData = ctx.prepareChildFork(rref_); return toPyTuple(rrefForkData); } PyRRef PyRRef::unpickle(const py::tuple& pyTuple) { auto& ctx = RRefContext::getInstance(); auto rrefForkData = fromPyTuple(pyTuple); TypePtr rrefType = PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_); c10::intrusive_ptr rref = ctx.getOrCreateRRef(rrefForkData, rrefType); ctx.notifyOwnerAndParentOfFork( rrefForkData.forkId_, rrefForkData.parent_, rref); return PyRRef(std::move(rref)); } c10::IValue PyRRef::toIValue() const { // cast to RRefInterface to hold it into IValue auto rrefPtr = c10::static_intrusive_pointer_cast(rref_); return IValue(rrefPtr); } void PyRRef::backward(int64_t autogradContextId, bool retainGraph) { backward(autogradContextId, retainGraph, rref_); } void PyRRef::backwardOwnerRRef( int64_t autogradContextId, bool retainGraph, IValue value) { // If we have a PyObj, retrieve the underlying tensor. if (value.isPyObject()) { py::gil_scoped_acquire gil; py::object obj = torch::jit::toPyObject(value); try { value = torch::jit::toIValue(obj, c10::TensorType::get()); } catch (py::cast_error& e) { TORCH_CHECK(false, "RRef should contain a tensor for .backward()"); } } TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()"); auto root = value.toTensor(); if (autogradContextId == -1) { torch::autograd::backward({root}); } else { torch::distributed::autograd::backward( autogradContextId, {root}, retainGraph); } } void PyRRef::backward( int64_t autogradContextId, bool retainGraph, const c10::intrusive_ptr& rref) { if (rref->isOwner()) { backwardOwnerRRef( autogradContextId, retainGraph, c10::static_intrusive_pointer_cast(rref)->getValue()); } else { TORCH_CHECK( autogradContextId != -1, "User RRefs require 'dist_autograd_ctx_id' to be specified"); autograd::RRefBackwardReq rrefBackwardReq( rref->rrefId(), autogradContextId, retainGraph); // Invoke distributed backward remotely. auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent(); rpcAgent ->send( rpcAgent->getWorkerInfo(rref->owner()), std::move(rrefBackwardReq).toMessage()) ->waitAndThrow(); } } } // namespace rpc } // namespace distributed } // namespace torch