Files
pytorch/torch/csrc/distributed/rpc/py_rref.cpp
Shen Li f99a693cd9 Remove unnecessary py::object copy in PyRRef ctor (#38402)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38402

Test Plan: Imported from OSS

Differential Revision: D21554724

Pulled By: mrshenli

fbshipit-source-id: abab45010810ec53628ea2c7a9c76cdc50eb2f74
2020-05-13 22:00:13 -07:00

250 lines
8.4 KiB
C++

#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
///////////////////// Pickle/Unpickle Helplers ////////////////////////////
namespace {
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
// add GIL as it is contructing a py::object
pybind11::gil_scoped_acquire ag;
return py::make_tuple(
rrefForkData.ownerId_,
rrefForkData.rrefId_.createdOn_,
rrefForkData.rrefId_.localId_,
rrefForkData.forkId_.createdOn_,
rrefForkData.forkId_.localId_,
rrefForkData.parent_,
rrefForkData.typeStr_);
}
RRefForkData fromPyTuple(const py::tuple& pyTuple) {
// add GIL as it is accessing a py::object
pybind11::gil_scoped_acquire ag;
TORCH_INTERNAL_ASSERT(
pyTuple.size() == RFD_TUPLE_SIZE,
"Pickled RRefForkData must contain ",
RFD_TUPLE_SIZE,
" numbers.");
worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
// const reference will extend the lifetime of the temporary variable
const RRefId& rrefId = RRefId(
pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
pyTuple[RREFID_ID_IDX].cast<local_id_t>());
const RRefId& forkId = RRefId(
pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
pyTuple[FORKID_ID_IDX].cast<local_id_t>());
worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
TypePtr tryInferTypeWithTypeHint(
const py::object& value,
const py::object& type_hint) {
// If the py::object to be contained by the RRef is a ScripModule, we enforce
// users to specify its ModuleInterface type.
if (auto module = jit::as_module(value)) {
TORCH_CHECK(
!type_hint.is_none(),
"The RRef being created contains a ScriptModule, "
"must provide its ModuleInterface type hint. ");
c10::QualifiedName type_qualified_name = c10::QualifiedName(
py::cast<std::string>(py::module::import("torch.jit")
.attr("_qualified_name")(type_hint)));
TypePtr type_hint_ptr =
jit::get_python_cu()->get_interface(type_qualified_name);
std::ostringstream subtype_check_msg;
TORCH_CHECK(
type_hint_ptr != nullptr &&
module.value().type()->isSubtypeOfExt(
type_hint_ptr, &subtype_check_msg),
module.value().type()->python_str(),
" is not a subtype of the type hint: ",
type_qualified_name.qualifiedName(),
", did you pass a valid interface type?\n",
subtype_check_msg.str());
return type_hint_ptr;
} else {
TORCH_CHECK(
type_hint.is_none(),
"type_hint should only be specified when the RRef being created contains a ScriptModule.");
}
// NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
// excluding ScripModule.
jit::InferredType type_inferred = jit::tryToInferType(value);
if (type_inferred.success()) {
// If we could infer the type from the pyobject, we create
// the RRef with the IValue of that type.
return type_inferred.type();
}
// Otherwise it's a pure pyobject, create the RRef
// that holds an IValue of an pyobject.
return PyObjectType::get();
}
} // namespace
/////////////////////////// PyRRef //////////////////////////////////
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
}
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
: PyRRef([&value, &type_hint]() {
TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
// jit::toIValue takes a py::handle as the first argument, and it calls
// py::handle.cast<py::object>() to incref of provided value. The
// returned ivalue will keep the reference alive.
// NB: the first argument const py::object& value must be kept alive
// until the following jit::toIValue returns (i.e., incref done). That's
// why this ctor can only be called while holding GIL.
IValue ivalue = jit::toIValue(value, elem_type);
rref->setValue(std::move(ivalue));
return rref;
}()) {}
c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
// Marking hasValue to false, as this Future is only used for signaling
// profiler to update profiling result and the profiler does not retrieve
// any value from it.
return wrapFutureMessageInJitFuture(
rref_->getOwnerCreationFuture(), false /* hasValue */);
}
bool PyRRef::isOwner() const {
return rref_->isOwner();
}
bool PyRRef::confirmedByOwner() const {
return rref_->confirmedByOwner();
}
WorkerInfo PyRRef::owner() const {
return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
}
std::string PyRRef::ownerName() const {
return rref_->ownerName();
}
py::object PyRRef::toHere() {
if (rref_->isOwner()) {
return localValue();
} else {
// toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
// a python object
IValue value =
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();
if (rref_->isPyObj()) {
// python_rpc_handler deserialization will acquires GIL.
auto rfr_values = value.toTuple()->elements();
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
auto ret = pythonRpcHandler.deserialize(
SerializedPyObj::fromIValues(rfr_values));
pythonRpcHandler.handleException(ret);
return ret;
} else {
// acquiring GIL as torch::jit::toPyObject creates new py::object
// without grabbing the GIL.
pybind11::gil_scoped_acquire ag;
return torch::jit::toPyObject(std::move(value));
}
}
}
py::object PyRRef::localValue() {
TORCH_CHECK(
rref_->isOwner(),
"Cannot call localValue() on a non-local reference. Call it on ",
owner().name_);
py::object res;
auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
auto& rpcHandler = PythonRpcHandler::getInstance();
{
// acquiring GIL as torch::jit::toPyObject creates new py::object without
// grabbing the GIL.
pybind11::gil_scoped_acquire ag;
res = torch::jit::toPyObject(std::move(value));
rpcHandler.handleExceptionGILHeld(res);
}
return res;
}
std::string PyRRef::str() const {
if (rref_->isOwner()) {
return c10::str("OwnerRRef(", rref_->rrefId(), ")");
} else {
return c10::str(
"UserRRef(RRefId = ",
rref_->rrefId(),
", ForkId = ",
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
")");
}
}
py::object PyRRef::createRRefProxy(const RRefProxyType& type) const {
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
pybind11::gil_scoped_acquire ag;
auto& functions = pythonRpcHandler.getRRefProxyFunctions();
auto& ctor = functions.rrefProxyCtor_;
switch (type) {
case RRefProxyType::RPC_SYNC: {
return ctor(*this, functions.rpcSync_);
}
case RRefProxyType::RPC_ASYNC: {
return ctor(*this, functions.rpcAsync_);
}
case RRefProxyType::REMOTE: {
return ctor(*this, functions.remote_);
}
default: {
TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
}
}
}
py::tuple PyRRef::pickle() const {
auto& ctx = RRefContext::getInstance();
auto rrefForkData = ctx.prepareChildFork(rref_);
return toPyTuple(rrefForkData);
}
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
auto& ctx = RRefContext::getInstance();
auto rrefForkData = fromPyTuple(pyTuple);
TypePtr rrefType =
PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
ctx.notifyOwnerAndParentOfFork(
rrefForkData.forkId_, rrefForkData.parent_, rref);
return PyRRef(std::move(rref));
}
c10::IValue PyRRef::toIValue() {
// cast to RRefInterface to hold it into IValue
auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
return IValue(rrefPtr);
}
} // namespace rpc
} // namespace distributed
} // namespace torch