Files
pytorch/torch/csrc/distributed/rpc/py_rref.cpp
Shen Li d5b38984c8 Let RPC return FutureIValue instead of FutureMessage (#37519)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37519

closes #37446

Currently FutureMessage is used in several places:

1. `rpc_async` returns a `FutureMessage` object and we expose it
   as `torch.distributed.rpc.Future`. From an application's perspective,
   it expects a `py::object` instead of a `Message`, and we
   do the conversion in the `Future.wait()` pybind method.
2. RPC autograd profiler takes `FutureMessage` and installs
   callbacks to it. The profiler actually only needs a `Future<T>`
   and does not care what `T` is.
3. `OwnerRRef` exposes a `getFuture()` API which returns a
   `FutureMessage`. This `FutureMessage` will be marked completed
   when the value referenced by the `OwnerRRef` is ready.
   `OwnerRRef` does not need it to be a Message type either, it
   actually creates an empty `Message` to mark the `Future`.

The above places are using `FutureMessage`, but they don't really
need a `Message`, and `Message` is a communication layer type that
applications or profiler or the RRef shouldn't be aware of.

Another motivation for making this change is that for async RPC
UDF #36071, we are going to allow applications to call
`markCompleted` in Python. If we still use `FutureMessage`, then
in the `markCompleted` pybind function, it needs to convert the
provided `py::object` into a specific message type, which is
leaking communication layer code to pybind functions. Even if
this is doable, we will have two entities (RPC agent and pybind
Python frontend) accessing the same request callback logic. This is too messy.

This commit replaces all surface `FutureMessage` with `FutureIValue`,
so that `FutureMessage` is no longer visible from Python land. Note
that this does not cause BC issues, as the Python Future type name
and its API stay intact. Internally, we still have `FutureMessage`
in the communication layer.

Test Plan: Imported from OSS

Reviewed By: xush6528

Differential Revision: D21308887

Pulled By: mrshenli

fbshipit-source-id: 4f574f38e83125081f142813cfdde56119522089
2020-04-29 19:10:29 -07:00

244 lines
8.0 KiB
C++

#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
namespace {
// Flattens an RRefForkData into a Python tuple. Fields are emitted in this
// order: ownerId_, rrefId_.createdOn_, rrefId_.localId_, forkId_.createdOn_,
// forkId_.localId_, parent_, typeStr_.
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
  // Hold the GIL because a py::object is constructed below.
  pybind11::gil_scoped_acquire gilGuard;
  const auto& rrefId = rrefForkData.rrefId_;
  const auto& forkId = rrefForkData.forkId_;
  return py::make_tuple(
      rrefForkData.ownerId_,
      rrefId.createdOn_,
      rrefId.localId_,
      forkId.createdOn_,
      forkId.localId_,
      rrefForkData.parent_,
      rrefForkData.typeStr_);
}
// Rebuilds an RRefForkData from a Python tuple. The tuple must hold exactly
// RFD_TUPLE_SIZE elements (otherwise an internal assert fires); the fields are
// read from the OWNER_IDX, RREFID_ON_IDX, RREFID_ID_IDX, FORKID_ON_IDX,
// FORKID_ID_IDX, PARENT_IDX and TYPE_IDX positions.
RRefForkData fromPyTuple(const py::tuple& pyTuple) {
  // Hold the GIL because a py::object is accessed below.
  pybind11::gil_scoped_acquire gilGuard;
  TORCH_INTERNAL_ASSERT(
      pyTuple.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain ",
      RFD_TUPLE_SIZE,
      " numbers.");

  const auto ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
  const RRefId rrefId(
      pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
      pyTuple[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId forkId(
      pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
      pyTuple[FORKID_ID_IDX].cast<local_id_t>());
  const auto parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
  const auto typeStr = pyTuple[TYPE_IDX].cast<std::string>();

  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
// Decides which type the RRef should hold for `value`.
//
// - If `value` is a ScriptModule, `type_hint` is mandatory and must name an
//   interface that the module's type is a subtype of; that interface type is
//   returned.
// - Otherwise `type_hint` must be None. The type is inferred from `value` when
//   possible, and falls back to PyObjectType when inference fails.
//
// Violations of the rules above are reported through TORCH_CHECK.
TypePtr tryInferTypeWithTypeHint(
    const py::object& value,
    const py::object& type_hint) {
  auto maybeModule = jit::as_module(value);

  if (!maybeModule) {
    TORCH_CHECK(
        type_hint.is_none(),
        "type_hint should only be specified when the RRef being created contains a ScriptModule.");

    // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
    // excluding ScriptModule.
    jit::InferredType inferred = jit::tryToInferType(value);
    if (inferred.success()) {
      // The type could be inferred from the pyobject, so the RRef holds an
      // IValue of that type.
      return inferred.type();
    }
    // Otherwise it is a pure pyobject; the RRef holds an IValue of a pyobject.
    return PyObjectType::get();
  }

  // The py::object to be contained by the RRef is a ScriptModule, so users
  // are required to specify its ModuleInterface type.
  TORCH_CHECK(
      !type_hint.is_none(),
      "The RRef being created contains a ScriptModule, "
      "must provide its ModuleInterface type hint. ");

  const std::string hintQualifiedName = py::cast<std::string>(
      py::module::import("torch.jit").attr("_qualified_name")(type_hint));
  const c10::QualifiedName type_qualified_name(hintQualifiedName);
  TypePtr type_hint_ptr =
      jit::get_python_cu()->get_interface(type_qualified_name);

  // Collects the reason reported by the subtype check for the error message.
  std::ostringstream subtype_check_msg;
  TORCH_CHECK(
      type_hint_ptr != nullptr &&
          maybeModule.value().type()->isSubtypeOfExt(
              type_hint_ptr, &subtype_check_msg),
      maybeModule.value().type()->python_str(),
      " is not a subtype of the type hint: ",
      type_qualified_name.qualifiedName(),
      ", did you pass a valid interface type?\n",
      subtype_check_msg.str());
  return type_hint_ptr;
}
} // namespace
/////////////////////////// PyRRef //////////////////////////////////
// Wraps an existing RRef. A null pointer is rejected via TORCH_CHECK.
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
  const bool hasTarget = static_cast<bool>(rref_);
  TORCH_CHECK(hasTarget, "PyRRef must not wrap nullptr");
}
// Creates an owner RRef holding `value`. The element type comes from
// tryInferTypeWithTypeHint(value, type_hint); the value is converted to an
// IValue of that type and stored in the new RRef before delegating to the
// wrapping constructor.
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
    : PyRRef([&]() {
        TypePtr elemType = tryInferTypeWithTypeHint(value, type_hint);
        auto ownerRRef = RRefContext::getInstance().createOwnerRRef(elemType);
        py::object valueCopy(value); // increases refcount
        ownerRRef->setValue(jit::toIValue(std::move(valueCopy), elemType));
        return ownerRRef;
      }()) {}
// Returns the owner-creation future of the wrapped RRef converted to a
// FutureIValue.
const std::shared_ptr<FutureIValue> PyRRef::getFuture() const {
  // hasValue is false because this Future is only used for signaling the
  // profiler to update profiling results; the profiler does not retrieve any
  // value from it.
  constexpr bool kHasValue = false;
  return toFutureIValue(rref_->getOwnerCreationFuture(), kHasValue);
}
// Forwards to RRef::isOwner() on the wrapped RRef.
bool PyRRef::isOwner() const {
  const bool ownedHere = rref_->isOwner();
  return ownedHere;
}
// Forwards to RRef::confirmedByOwner() on the wrapped RRef.
bool PyRRef::confirmedByOwner() const {
  const bool confirmed = rref_->confirmedByOwner();
  return confirmed;
}
// Looks up the WorkerInfo for the wrapped RRef's owner id through the RPC
// agent held by the RRefContext.
WorkerInfo PyRRef::owner() const {
  const auto ownerId = rref_->owner();
  return RRefContext::getInstance().agent()->getWorkerInfo(ownerId);
}
// Forwards to RRef::ownerName() on the wrapped RRef.
std::string PyRRef::ownerName() const {
  std::string name = rref_->ownerName();
  return name;
}
// Returns the referenced value as a py::object.
//
// On the owner this is simply localValue(). On a user, the value is fetched
// through UserRRef::toHere() and then either deserialized by the
// PythonRpcHandler (when the RRef holds a Python object) or converted with
// torch::jit::toPyObject.
py::object PyRRef::toHere() {
  if (rref_->isOwner()) {
    return localValue();
  }

  // toHere() calls python_rpc_handler which acquires GIL when UserRRef holds
  // a python object
  IValue fetched =
      c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();

  if (!rref_->isPyObj()) {
    // torch::jit::toPyObject creates a new py::object without grabbing the
    // GIL, so acquire it here.
    pybind11::gil_scoped_acquire gilGuard;
    return torch::jit::toPyObject(std::move(fetched));
  }

  // python_rpc_handler deserialization acquires the GIL itself.
  auto serializedElems = fetched.toTuple()->elements();
  auto& handler = PythonRpcHandler::getInstance();
  auto result =
      handler.deserialize(SerializedPyObj::fromIValues(serializedElems));
  handler.handleException(result);
  return result;
}
// Returns the value held by a locally-owned RRef as a py::object. Fails with
// TORCH_CHECK when the wrapped RRef is not an owner. The converted object is
// passed to PythonRpcHandler::handleExceptionGILHeld before being returned.
py::object PyRRef::localValue() {
  TORCH_CHECK(
      rref_->isOwner(),
      "Cannot call localValue() on a non-local reference. Call it on ",
      owner().name_);

  py::object pyValue;
  auto ownedValue =
      c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
  // The handler is fetched before the GIL is taken below.
  auto& handler = PythonRpcHandler::getInstance();
  {
    // torch::jit::toPyObject creates a new py::object without grabbing the
    // GIL, so acquire it for the conversion and the exception check.
    pybind11::gil_scoped_acquire gilGuard;
    pyValue = torch::jit::toPyObject(std::move(ownedValue));
    handler.handleExceptionGILHeld(pyValue);
  }
  return pyValue;
}
std::string PyRRef::str() const {
if (rref_->isOwner()) {
return c10::str("OwnerRRef(", rref_->rrefId(), ")");
} else {
return c10::str(
"UserRRef(RRefId = ",
rref_->rrefId(),
", ForkId = ",
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
")");
}
}
// Builds a Python RRef proxy object for `self` by calling the proxy
// constructor obtained from the PythonRpcHandler with the Python function
// matching `type` (rpcSync_, rpcAsync_ or remote_). An unrecognized `type`
// trips an internal assert.
py::object PyRRef::createRRefProxy(PyRRef& self, const RRefProxyType& type)
    const {
  auto& handler = PythonRpcHandler::getInstance();
  pybind11::gil_scoped_acquire gilGuard;
  auto& proxyFns = handler.getRRefProxyFunctions();
  auto& makeProxy = proxyFns.rrefProxyCtor_;

  if (type == RRefProxyType::RPC_SYNC) {
    return makeProxy(self, proxyFns.rpcSync_);
  }
  if (type == RRefProxyType::RPC_ASYNC) {
    return makeProxy(self, proxyFns.rpcAsync_);
  }
  if (type == RRefProxyType::REMOTE) {
    return makeProxy(self, proxyFns.remote_);
  }
  TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
}
// Produces the pickled form of this RRef: asks the RRefContext to prepare a
// child fork and converts the resulting RRefForkData into a Python tuple.
py::tuple PyRRef::pickle() const {
  const auto forkData = RRefContext::getInstance().prepareChildFork(rref_);
  return toPyTuple(forkData);
}
// Reconstructs a PyRRef from its pickled tuple: decodes the fork data, parses
// the element type from its type string, gets or creates the RRef through the
// RRefContext, and notifies the owner and parent of the fork.
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
  auto& rrefCtx = RRefContext::getInstance();
  const auto forkData = fromPyTuple(pyTuple);
  TypePtr elemType =
      PythonRpcHandler::getInstance().parseTypeFromStr(forkData.typeStr_);
  c10::intrusive_ptr<RRef> restored =
      rrefCtx.getOrCreateRRef(forkData, elemType);
  rrefCtx.notifyOwnerAndParentOfFork(
      forkData.forkId_, forkData.parent_, restored);
  return PyRRef(std::move(restored));
}
// Wraps the underlying RRef in an IValue. The pointer is cast to
// c10::RRefInterface first, since that is the type the IValue holds.
c10::IValue PyRRef::toIValue() {
  return IValue(
      c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_));
}
} // namespace rpc
} // namespace distributed
} // namespace torch