#include <torch/csrc/distributed/rpc/py_rref.h>

#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>

namespace torch {
namespace distributed {
namespace rpc {

///////////////////// Pickle/Unpickle Helpers ////////////////////////////

namespace {

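// NB: the field order produced here must stay in sync with the
// OWNER_IDX/RREFID_*_IDX/FORKID_*_IDX/PARENT_IDX/TYPE_IDX indices that
// fromPyTuple() below uses to unpack the tuple.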
py::tuple toPyTuple(const RRefForkData& rrefForkData) {
  // add GIL as it is constructing a py::object
  pybind11::gil_scoped_acquire ag;
  return py::make_tuple(
      rrefForkData.ownerId_,
      rrefForkData.rrefId_.createdOn_,
      rrefForkData.rrefId_.localId_,
      rrefForkData.forkId_.createdOn_,
      rrefForkData.forkId_.localId_,
      rrefForkData.parent_,
      rrefForkData.typeStr_);
}

RRefForkData fromPyTuple(const py::tuple& pyTuple) {
  // add GIL as it is accessing a py::object
  pybind11::gil_scoped_acquire ag;
  TORCH_INTERNAL_ASSERT(
      pyTuple.size() == RFD_TUPLE_SIZE,
      "Pickled RRefForkData must contain ",
      RFD_TUPLE_SIZE,
      " numbers.");
  worker_id_t ownerId = pyTuple[OWNER_IDX].cast<worker_id_t>();
  // const reference will extend the lifetime of the temporary variable
  const RRefId& rrefId = RRefId(
      pyTuple[RREFID_ON_IDX].cast<worker_id_t>(),
      pyTuple[RREFID_ID_IDX].cast<local_id_t>());
  const RRefId& forkId = RRefId(
      pyTuple[FORKID_ON_IDX].cast<worker_id_t>(),
      pyTuple[FORKID_ID_IDX].cast<local_id_t>());

  worker_id_t parent = pyTuple[PARENT_IDX].cast<worker_id_t>();
  const std::string& typeStr = pyTuple[TYPE_IDX].cast<std::string>();

  return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
}

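// Infers the element type for the RRef being created: a ScriptModule must come
// with an explicit ModuleInterface type hint; for other objects we try JIT
// type inference on the pyobject and fall back to PyObjectType when nothing
// more specific can be inferred.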
TypePtr tryInferTypeWithTypeHint(
    const py::object& value,
    const py::object& type_hint) {
  // If the py::object to be contained by the RRef is a ScriptModule, we
  // require users to specify its ModuleInterface type.
  if (auto module = jit::as_module(value)) {
    TORCH_CHECK(
        !type_hint.is_none(),
        "The RRef being created contains a ScriptModule, "
        "must provide its ModuleInterface type hint. ");
    c10::QualifiedName type_qualified_name = c10::QualifiedName(
        py::cast<std::string>(py::module::import("torch._jit_internal")
                                  .attr("_qualified_name")(type_hint)));
    TypePtr type_hint_ptr =
        jit::get_python_cu()->get_interface(type_qualified_name);
    std::ostringstream subtype_check_msg;
    TORCH_CHECK(
        type_hint_ptr != nullptr &&
            module.value().type()->isSubtypeOfExt(
                *type_hint_ptr, &subtype_check_msg),
        module.value().type()->repr_str(),
        " is not a subtype of the type hint: ",
        type_qualified_name.qualifiedName(),
        ", did you pass a valid interface type?\n",
        subtype_check_msg.str());
    return type_hint_ptr;
  } else {
    TORCH_CHECK(
        type_hint.is_none(),
        "type_hint should only be specified when the RRef being created contains a ScriptModule.");
  }

  // Check if value is an instance of a ScriptClass. If not, skip type
  // inference, because it would try to script the class that value is an
  // instance of, and this should be avoided.
  py::bool_ can_compile = py::module::import("torch._jit_internal")
                              .attr("can_compile_class")(value.get_type());

  if (py::cast<bool>(can_compile)) {
    py::object existing_ty = py::module::import("torch.jit._state")
                                 .attr("_get_script_class")(value.get_type());

    if (existing_ty.is_none()) {
      return PyObjectType::get();
    }
  }

  // NB: `jit::tryToInferType(..)` infers types including ScriptClass, but
  // excluding ScriptModule.
  jit::InferredType type_inferred = jit::tryToInferType(value);
  if (type_inferred.success()) {
    // If we could infer the type from the pyobject, we create
    // the RRef with the IValue of that type.
    return type_inferred.type();
  }

  // Otherwise it's a pure pyobject; create the RRef
  // that holds an IValue of a pyobject.
  return PyObjectType::get();
}

} // namespace

/////////////////////////// PyRRef //////////////////////////////////

PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
    : rref_(std::move(rref)), profilingFuture_(c10::nullopt) {
  TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
  C10_LOG_API_USAGE_ONCE("torch.distributed.rref");
}

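// This ctor creates an OwnerRRef on the local worker and populates it with
// `value` converted to an IValue of the inferred (or hinted) type. It must be
// called with the GIL held (see the note inside the lambda below).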
PyRRef::PyRRef(const py::object& value, const py::object& type_hint)
    : PyRRef([&value, &type_hint]() mutable {
        TypePtr elem_type = tryInferTypeWithTypeHint(value, type_hint);
        auto rref = RRefContext::getInstance().createOwnerRRef(elem_type);
        // jit::toIValue takes a py::handle as the first argument, and it calls
        // py::handle.cast<py::object>() to incref the provided value. The
        // returned ivalue will keep the reference alive.
        // NB: the first argument const py::object& value must be kept alive
        // until the following jit::toIValue returns (i.e., incref done). That's
        // why this ctor can only be called while holding the GIL.
        IValue ivalue = jit::toIValue(value, elem_type);
        rref->setValue(std::move(ivalue));
        return rref;
      }()) {}

PyRRef::~PyRRef() {
  if (type_.has_value()) {
    (*type_).dec_ref();
    // explicitly set the PyObject* to nullptr to prevent py::object's dtor
    // from decref'ing the PyObject again.
    // See Note [Destructing py::object] in python_ivalue.h
    (*type_).ptr() = nullptr;
  }
}

c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
  // Mark hasValue as false, as this Future is only used to signal the
  // profiler to update the profiling result, and the profiler does not
  // retrieve any value from it.
  return toPyJitFuture(rref_->getOwnerCreationFuture(), false /* hasValue */);
}

c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
  TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
  return *profilingFuture_;
}

void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) {
  profilingFuture_ = std::move(profilingFuture);
}

bool PyRRef::isOwner() const {
  return rref_->isOwner();
}

bool PyRRef::confirmedByOwner() const {
  return rref_->confirmedByOwner();
}

WorkerInfo PyRRef::owner() const {
  return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
}

std::string PyRRef::ownerName() const {
  return rref_->ownerName();
}

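// Fetches the value referenced by this RRef to the calling worker: owners
// return the local value directly, while users block (up to timeoutSeconds)
// on a fetch from the owner.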
py::object PyRRef::toHere(const float timeoutSeconds) const {
  C10_LOG_API_USAGE_ONCE("torch.distributed.rref.to_here");
  if (rref_->isOwner()) {
    return localValue();
  } else {
    // toHere() calls python_rpc_handler, which acquires the GIL when the
    // UserRRef holds a python object
    IValue value = c10::static_intrusive_pointer_cast<UserRRef>(rref_)->toHere(
        timeoutSeconds);

    if (rref_->isPyObj()) {
      // python_rpc_handler deserialization will acquire the GIL.
      auto rfr_values = value.toTupleRef().elements().vec();
      auto& pythonRpcHandler = PythonRpcHandler::getInstance();
      auto ret = pythonRpcHandler.deserialize(
          SerializedPyObj::fromIValues(std::move(rfr_values)));
      pythonRpcHandler.handleException(ret);
      return ret;
    } else {
      // acquire the GIL, as torch::jit::toPyObject creates a new py::object
      // without grabbing the GIL.
      pybind11::gil_scoped_acquire ag;
      return torch::jit::toPyObject(std::move(value));
    }
  }
}

py::object PyRRef::localValue() const {
  TORCH_CHECK(
      rref_->isOwner(),
      "For ",
      *rref_,
      ", can't call localValue() on user ",
      RRefContext::getInstance().agent()->getWorkerInfo(),
      ". Call it on owner ",
      owner());

  py::object res;
  auto value =
      c10::static_intrusive_pointer_cast<const OwnerRRef>(rref_)->getValue();
  auto& rpcHandler = PythonRpcHandler::getInstance();
  {
    // acquire the GIL, as torch::jit::toPyObject creates a new py::object
    // without grabbing the GIL.
    pybind11::gil_scoped_acquire ag;
    res = torch::jit::toPyObject(std::move(value));
    rpcHandler.handleExceptionGILHeld(res);
  }
  return res;
}

std::string PyRRef::str() const {
  if (rref_->isOwner()) {
    return c10::str("OwnerRRef(", rref_->rrefId(), ")");
  } else {
    return c10::str(
        "UserRRef(RRefId = ",
        rref_->rrefId(),
        ", ForkId = ",
        c10::static_intrusive_pointer_cast<UserRRef>(rref_)->forkId(),
        ")");
  }
}

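// Constructs the Python-side RRef proxy by invoking the cached proxy ctor with
// this RRef, the function matching the requested proxy flavor (rpc_sync,
// rpc_async, or remote), and the given timeout.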
py::object PyRRef::createRRefProxy(
    const RRefProxyType& type,
    float timeoutSeconds) const {
  auto& pythonRpcHandler = PythonRpcHandler::getInstance();
  pybind11::gil_scoped_acquire ag;
  auto& functions = pythonRpcHandler.getRRefProxyFunctions();
  auto& ctor = functions.rrefProxyCtor_;
  switch (type) {
    case RRefProxyType::RPC_SYNC: {
      return ctor(*this, functions.rpcSync_, timeoutSeconds);
    }
    case RRefProxyType::RPC_ASYNC: {
      return ctor(*this, functions.rpcAsync_, timeoutSeconds);
    }
    case RRefProxyType::REMOTE: {
      return ctor(*this, functions.remote_, timeoutSeconds);
    }
    default: {
      TORCH_INTERNAL_ASSERT(false, "Unrecognized RRefProxy type ", type);
    }
  }
}

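// Lazily retrieves the type of the value held by this RRef, locally on the
// owner or via the user-side helper otherwise, and caches the result in type_.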
py::object PyRRef::getRRefType(float timeout, bool blocking) {
  // GIL is not released when calling this function.
  if (!type_.has_value()) {
    pybind11::gil_scoped_release release;
    auto& pythonRpcHandler = PythonRpcHandler::getInstance();
    auto& typeFuncs = pythonRpcHandler.getRRefTypeFunctions();
    pybind11::gil_scoped_acquire acquire;
    type_ = isOwner() ? typeFuncs.onOwner_(*this, blocking)
                      : typeFuncs.onUser_(*this, timeout, blocking);
  }
  // Returns a py::object that can be a Python type or a future.
  return *type_;
}

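// Pickling registers a child fork for this RRef with the local RRefContext and
// serializes the resulting RRefForkData; unpickle() below rebuilds the RRef
// and notifies the owner and parent of the new fork.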
py::tuple PyRRef::pickle() const {
  auto& ctx = RRefContext::getInstance();
  auto rrefForkData = ctx.prepareChildFork(rref_);
  return toPyTuple(rrefForkData);
}

PyRRef PyRRef::unpickle(const py::tuple& pyTuple) {
  auto& ctx = RRefContext::getInstance();
  auto rrefForkData = fromPyTuple(pyTuple);
  TypePtr rrefType =
      PythonRpcHandler::getInstance().parseTypeFromStr(rrefForkData.typeStr_);
  c10::intrusive_ptr<RRef> rref = ctx.getOrCreateRRef(rrefForkData, rrefType);
  ctx.notifyOwnerAndParentOfFork(
      rrefForkData.forkId_, rrefForkData.parent_, rref);
  return PyRRef(std::move(rref));
}

c10::IValue PyRRef::toIValue() const {
  // cast to RRefInterface to hold it in an IValue
  auto rrefPtr = c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref_);
  return IValue(rrefPtr);
}

void PyRRef::backward(int64_t autogradContextId, bool retainGraph) {
  backward(autogradContextId, retainGraph, rref_);
}

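// Runs backward with the owner's value as the root: a PyObject value is first
// unwrapped to its tensor, then the call dispatches to local autograd when no
// dist_autograd context id is given, or to distributed autograd otherwise.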
void PyRRef::backwardOwnerRRef(
    int64_t autogradContextId,
    bool retainGraph,
    IValue value) {
  // If we have a PyObj, retrieve the underlying tensor.
  if (value.isPyObject()) {
    py::gil_scoped_acquire gil;
    py::object obj = torch::jit::toPyObject(value);
    try {
      value = torch::jit::toIValue(obj, c10::TensorType::get());
    } catch (py::cast_error& e) {
      TORCH_CHECK(false, "RRef should contain a tensor for .backward()");
    }
  }

  TORCH_CHECK(value.isTensor(), "RRef should contain a tensor for .backward()");
  auto root = value.toTensor();

  if (autogradContextId == -1) {
    torch::autograd::backward({root});
  } else {
    torch::distributed::autograd::backward(
        autogradContextId, {root}, retainGraph);
  }
}

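// Owner RRefs run backward locally on the stored value; user RRefs send an
// RRefBackwardReq to the owner and wait for the remote backward to complete.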
void PyRRef::backward(
    int64_t autogradContextId,
    bool retainGraph,
    const c10::intrusive_ptr<RRef>& rref) {
  if (rref->isOwner()) {
    backwardOwnerRRef(
        autogradContextId,
        retainGraph,
        c10::static_intrusive_pointer_cast<const OwnerRRef>(rref)->getValue());
  } else {
    TORCH_CHECK(
        autogradContextId != -1,
        "User RRefs require 'dist_autograd_ctx_id' to be specified");

    autograd::RRefBackwardReq rrefBackwardReq(
        rref->rrefId(), autogradContextId, retainGraph);

    // Invoke distributed backward remotely.
    auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
    rpcAgent
        ->send(
            rpcAgent->getWorkerInfo(rref->owner()),
            std::move(rrefBackwardReq).toMessage())
        ->waitAndThrow();
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch