mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37519 closes #37446 Currently FutureMessage is used in several places: 1. `rpc_async` returns a `FutureMessage` object and we expose it as `torch.distributed.rpc.Future`. From the application's perspective, it expects a `py::object` instead of a `Message`, and we do the conversion in the `Future.wait()` pybind method. 2. The RPC autograd profiler takes a `FutureMessage` and installs callbacks on it. The profiler actually only needs a `Future<T>` and does not care what `T` is. 3. `OwnerRRef` exposes a `getFuture()` API which returns a `FutureMessage`. This `FutureMessage` will be marked completed when the value referenced by the `OwnerRRef` is ready. `OwnerRRef` does not need it to be a Message type either; it actually creates an empty `Message` to mark the `Future`. The above places are using `FutureMessage`, but they don't really need a `Message`, and `Message` is a communication layer type that applications, the profiler, or the RRef shouldn't be aware of. Another motivation for making this change is that for async RPC UDF #36071, we are going to allow applications to call `markCompleted` in Python. If we still use `FutureMessage`, then in the `markCompleted` pybind function, it needs to convert the provided `py::object` into a specific message type, which is leaking communication layer code to pybind functions. Even if this is doable, we will have two entities (RPC agent and pybind Python frontend) accessing the same request callback logic. This is too messy. This commit replaces all surface `FutureMessage` with `FutureIValue`, so that `FutureMessage` is no longer visible from Python land. Note that this does not cause BC issues, as the Python Future type name and its API stay intact. Internally, we still have `FutureMessage` in the communication layer.
Test Plan: Imported from OSS Reviewed By: xush6528 Differential Revision: D21308887 Pulled By: mrshenli fbshipit-source-id: 4f574f38e83125081f142813cfdde56119522089
244 lines
8.0 KiB
C++
244 lines
8.0 KiB
C++
#include <torch/csrc/distributed/rpc/py_rref.h>
|
|
|
|
#include <torch/csrc/distributed/rpc/python_functions.h>
|
|
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
|
#include <torch/csrc/distributed/rpc/rref_context.h>
|
|
#include <torch/csrc/jit/python/module_python.h>
|
|
#include <torch/csrc/jit/python/pybind_utils.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
///////////////////// Pickle/Unpickle Helpers ////////////////////////////
|
|
|
|
namespace {
|
|
|
|
// Flattens an RRefForkData into a Python tuple. The fields are emitted in this
// order: owner id, RRefId (createdOn, localId), ForkId (createdOn, localId),
// parent worker id, and the type string.
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
  // The GIL is required because we are constructing py::objects here.
  pybind11::gil_scoped_acquire gilGuard;
  const auto& rrefId = rrefForkData.rrefId_;
  const auto& forkId = rrefForkData.forkId_;
  return py::make_tuple(
      rrefForkData.ownerId_,
      rrefId.createdOn_,
      rrefId.localId_,
      forkId.createdOn_,
      forkId.localId_,
      rrefForkData.parent_,
      rrefForkData.typeStr_);
}
|
|
// Rebuilds an RRefForkData from a Python tuple. The tuple must have exactly
// RFD_TUPLE_SIZE entries; each field is read from the slot named by the
// corresponding *_IDX constant.
RRefForkData fromPyTuple(const py::tuple& pyTuple) {
  // The GIL is required because we are reading from a py::object here.
  pybind11::gil_scoped_acquire gilGuard;
  TORCH_INTERNAL_ASSERT(
      pyTuple.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain ",
      RFD_TUPLE_SIZE,
      " numbers.");

  const auto ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
  const RRefId rrefId(
      pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
      pyTuple[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId forkId(
      pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
      pyTuple[FORKID_ID_IDX].cast<local_id_t>());
  const auto parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
  const auto typeStr = pyTuple[TYPE_IDX].cast<std::string>();

  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}
|
|
|
|
// Picks the element type for an RRef that is about to hold `value`.
//
//  * `value` is a ScriptModule: `type_hint` is mandatory and must name an
//    interface type that the module's type is a subtype of; that interface
//    type is returned.
//  * Anything else: `type_hint` must be None. The type inferred by
//    jit::tryToInferType is returned when inference succeeds, otherwise
//    PyObjectType.
//
// Violations of the rules above are reported through TORCH_CHECK.
TypePtr tryInferTypeWithTypeHint(
    const py::object& value,
    const py::object& type_hint) {
  auto module = jit::as_module(value);

  if (!module) {
    TORCH_CHECK(
        type_hint.is_none(),
        "type_hint should only be specified when the RRef being created contains a ScriptModule.");

    // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
    // excluding ScriptModule.
    jit::InferredType inferred = jit::tryToInferType(value);
    if (inferred.success()) {
      // The type could be inferred from the pyobject, so the RRef holds an
      // IValue of that type.
      return inferred.type();
    }
    // Otherwise it is a pure pyobject; the RRef holds an IValue wrapping it.
    return PyObjectType::get();
  }

  // If the py::object to be contained by the RRef is a ScriptModule, users are
  // required to specify its ModuleInterface type.
  TORCH_CHECK(
      !type_hint.is_none(),
      "The RRef being created contains a ScriptModule, "
      "must provide its ModuleInterface type hint. ");

  const py::object jitModule = py::module::import("torch.jit");
  c10::QualifiedName hintName(
      py::cast<std::string>(jitModule.attr("_qualified_name")(type_hint)));
  TypePtr hintType = jit::get_python_cu()->get_interface(hintName);

  // Receives the explanation when the subtype check fails.
  std::ostringstream whyNotSubtype;
  TORCH_CHECK(
      hintType != nullptr &&
          module.value().type()->isSubtypeOfExt(hintType, &whyNotSubtype),
      module.value().type()->python_str(),
      " is not a subtype of the type hint: ",
      hintName.qualifiedName(),
      ", did you pass a valid interface type?\n",
      whyNotSubtype.str());
  return hintType;
}
|
|
|
|
} // namespace
|
|
|
|
/////////////////////////// PyRRef //////////////////////////////////
|
|
|
|
// Wraps an already-existing RRef. A null pointer is rejected via TORCH_CHECK.
PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
  TORCH_CHECK(rref_.defined(), "PyRRef must not wrap nullptr");
}
|
|
|
|
// Creates an OwnerRRef through the RRefContext and stores `value` in it. The
// element type comes from tryInferTypeWithTypeHint(value, type_hint); the
// resulting RRef is handed to the intrusive_ptr constructor above.
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
    : PyRRef([&]() {
        TypePtr valueType = tryInferTypeWithTypeHint(value, type_hint);
        auto ownerRRef =
            RRefContext::getInstance().createOwnerRRef(valueType);
        // Convert from a fresh py::object so that the conversion works on its
        // own reference (copying increases the refcount).
        ownerRRef->setValue(jit::toIValue(py::object(value), valueType));
        return ownerRRef;
      }()) {}
|
|
|
|
const std::shared_ptr<FutureIValue> PyRRef::getFuture() const {
|
|
// Marking hasValue to false, as this Future is only used for signaling
|
|
// profiler to update profiling result and the profiler does not retrieve
|
|
// any value from it.
|
|
return toFutureIValue(rref_->getOwnerCreationFuture(), false /* hasValue */);
|
|
}
|
|
|
|
bool PyRRef::isOwner() const {
|
|
return rref_->isOwner();
|
|
}
|
|
|
|
bool PyRRef::confirmedByOwner() const {
|
|
return rref_->confirmedByOwner();
|
|
}
|
|
|
|
// Looks up the WorkerInfo of the RRef's owner through the RPC agent that is
// registered with the RRefContext.
WorkerInfo PyRRef::owner() const {
  const auto ownerId = rref_->owner();
  auto& context = RRefContext::getInstance();
  return context.agent()->getWorkerInfo(ownerId);
}
|
|
|
|
std::string PyRRef::ownerName() const {
|
|
return rref_->ownerName();
|
|
}
|
|
|
|
// Returns the value referenced by this RRef as a Python object.
//
//  * Owner RRef: equivalent to localValue().
//  * User RRef: fetches the value through UserRRef::toHere() and converts it
//    to a py::object. When the RRef holds a Python object, the fetched IValue
//    is a tuple that is deserialized by the PythonRpcHandler, and
//    handleException() is invoked on the result before it is returned.
py::object PyRRef::toHere() {
  if (rref_->isOwner()) {
    return localValue();
  }

  // Fetch the value before any GIL guard is taken in this function.
  IValue value =
      c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere();

  if (!rref_->isPyObj()) {
    // acquiring GIL as torch::jit::toPyObject creates new py::object
    // without grabbing the GIL.
    pybind11::gil_scoped_acquire ag;
    return torch::jit::toPyObject(std::move(value));
  }

  auto rfr_values = value.toTuple()->elements();
  // The handler is looked up before the GIL is acquired, mirroring the other
  // call sites in this file.
  auto& pythonRpcHandler = PythonRpcHandler::getInstance();

  // Hold the GIL for the whole lifetime of `ret`. If handleException() throws,
  // `ret` is destroyed during stack unwinding, and dropping a Python reference
  // is only safe with the GIL held. The guard is declared before `ret`, so it
  // outlives it. Acquiring here is harmless if the callee acquires again,
  // since gil_scoped_acquire is reentrant.
  pybind11::gil_scoped_acquire ag;
  py::object ret = pythonRpcHandler.deserialize(
      SerializedPyObj::fromIValues(rfr_values));
  pythonRpcHandler.handleException(ret);
  return ret;
}
|
|
|
|
// Returns the value held by this RRef as a Python object.
//
// Only valid on an owner RRef; calling it on a non-owner fails a TORCH_CHECK
// whose message names the owner worker. After conversion the result is passed
// to PythonRpcHandler::handleExceptionGILHeld() before being returned.
py::object PyRRef::localValue() {
  TORCH_CHECK(
      rref_->isOwner(),
      "Cannot call localValue() on a non-local reference. Call it on ",
      owner().name_);

  auto value = c10::static_intrusive_pointer_cast<OwnerRRef>(rref_)->getValue();
  // The handler is looked up before the GIL is acquired.
  auto& rpcHandler = PythonRpcHandler::getInstance();

  // acquiring GIL as torch::jit::toPyObject creates new py::object without
  // grabbing the GIL.
  //
  // `res` is declared after the guard on purpose: if handleExceptionGILHeld()
  // throws, `res` is destroyed (and its reference dropped) while the GIL is
  // still held, instead of after the guard has been released.
  pybind11::gil_scoped_acquire ag;
  py::object res = torch::jit::toPyObject(std::move(value));
  rpcHandler.handleExceptionGILHeld(res);
  return res;
}
|
|
|
|
std::string PyRRef::str() const {
|
|
if (rref_->isOwner()) {
|
|
return c10::str("OwnerRRef(", rref_->rrefId(), ")");
|
|
} else {
|
|
return c10::str(
|
|
"UserRRef(RRefId = ",
|
|
rref_->rrefId(),
|
|
", ForkId = ",
|
|
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
|
|
")");
|
|
}
|
|
}
|
|
|
|
// Builds a Python-side RRef proxy object by calling the proxy constructor
// obtained from the PythonRpcHandler with `self` and the RPC API function that
// matches `type` (rpc_sync, rpc_async or remote). An unrecognized `type` trips
// an internal assert.
py::object PyRRef::createRRefProxy(PyRRef& self, const RRefProxyType& type)
    const {
  auto& handler = PythonRpcHandler::getInstance();
  // The GIL is taken only after the handler has been looked up.
  pybind11::gil_scoped_acquire gilGuard;
  auto& proxyFns = handler.getRRefProxyFunctions();
  auto& makeProxy = proxyFns.rrefProxyCtor_;

  const auto buildWith = [&](const auto& rpcApi) -> py::object {
    return makeProxy(self, rpcApi);
  };

  switch (type) {
    case RRefProxyType::RPC_SYNC:
      return buildWith(proxyFns.rpcSync_);
    case RRefProxyType::RPC_ASYNC:
      return buildWith(proxyFns.rpcAsync_);
    case RRefProxyType::REMOTE:
      return buildWith(proxyFns.remote_);
    default:
      TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
  }
}
|
|
|
|
// Asks the RRefContext to prepare a child fork of the wrapped RRef and
// returns the resulting fork data as a Python tuple.
py::tuple PyRRef::pickle() const {
  return toPyTuple(RRefContext::getInstance().prepareChildFork(rref_));
}
|
|
|
|
// Inverse of pickle(): decodes the fork data from `pyTuple`, parses the RRef's
// element type from its type string, gets or creates the matching RRef in the
// RRefContext, and calls notifyOwnerAndParentOfFork() for the new fork before
// wrapping the RRef in a PyRRef.
PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
  auto& context = RRefContext::getInstance();
  auto forkData = fromPyTuple(pyTuple);
  TypePtr elemType =
      PythonRpcHandler::getInstance().parseTypeFromStr(forkData.typeStr_);
  c10::intrusive_ptr<RRef> forkedRRef =
      context.getOrCreateRRef(forkData, elemType);
  context.notifyOwnerAndParentOfFork(
      forkData.forkId_, forkData.parent_, forkedRRef);
  return PyRRef(std::move(forkedRRef));
}
|
|
|
|
// Wraps the underlying RRef in an IValue. The pointer is first cast to
// c10::RRefInterface, the type under which it is stored in the IValue.
c10::IValue PyRRef::toIValue() {
  return IValue(
      c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_));
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|