[rpc] Remove template on RRef and add Type to RRef creation (#30630)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630

This removes the template on RRef and all of its specializations in rpc. We
now universally use IValue as the inner value, since a Python object can be
held inside an IValue.

This also ensures that we have the correct type information when creating
the RRef: we use the return type from the schema when creating UserRRef and
OwnerRRef. This will enable an IValue to always have the correct type when
the IValue holds an RRef object (next PR).

Test Plan: Imported from OSS

Differential Revision: D19502235

fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
This commit is contained in:
Yanli Zhao
2020-01-23 21:09:23 -08:00
committed by Facebook Github Bot
parent ef2d4e67d1
commit b474c351dd
9 changed files with 219 additions and 238 deletions

View File

@ -8,21 +8,6 @@
namespace torch { namespace torch {
namespace distributed { namespace distributed {
namespace rpc { namespace rpc {
namespace {
// Constants below are used in PyRRef pickling and unpickling. PyRRef is
// converted into a py::tuple in pickling, and reconstructed from the py::tuple
// in unpickling.
constexpr int RFD_IDX = 0; // index of RRefForkData
constexpr int TYPE_IDX = 1; // index of type (py::object or IValue)
// number of data fields in the py::tuple.
// NB: if more fields are added, make sure this field is also bumped
constexpr int RREF_TUPLE_SIZE = 2;
} // namespace
/////////////////////////// PyRRef ////////////////////////////////// /////////////////////////// PyRRef //////////////////////////////////
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) { PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
@ -31,9 +16,11 @@ PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
PyRRef::PyRRef(const py::object& value) PyRRef::PyRRef(const py::object& value)
: PyRRef([&value]() { : PyRRef([&value]() {
auto rref = RRefContext::getInstance().createOwnerRRef<py::object>(); auto rref =
RRefContext::getInstance().createOwnerRRef(PyObjectType::get());
py::object copy(value); // increases refcount py::object copy(value); // increases refcount
rref->setValue(std::move(copy)); IValue py_ivalue = jit::toIValue(std::move(copy), PyObjectType::get());
rref->setValue(std::move(py_ivalue));
return rref; return rref;
}()) {} }()) {}
@ -52,10 +39,10 @@ py::object PyRRef::toHere() {
if (rref_->isPyObj()) { if (rref_->isPyObj()) {
// UserRRef<py::object>::toHere() calls python_rpc_handler which acquires // UserRRef<py::object>::toHere() calls python_rpc_handler which acquires
// GIL. // GIL.
return std::static_pointer_cast<UserRRef<py::object>>(rref_)->toHere(); return jit::toPyObject(
std::static_pointer_cast<UserRRef>(rref_)->toHere());
} else { } else {
IValue value = IValue value = std::static_pointer_cast<UserRRef>(rref_)->toHere();
std::static_pointer_cast<UserRRef<IValue>>(rref_)->toHere();
{ {
// acquiring GIL as torch::jit::toPyObject creates new py::object // acquiring GIL as torch::jit::toPyObject creates new py::object
@ -74,9 +61,8 @@ py::object PyRRef::localValue() {
owner().name_); owner().name_);
if (rref_->isPyObj()) { if (rref_->isPyObj()) {
const py::object& value = const py::object& value = jit::toPyObject(
std::dynamic_pointer_cast<OwnerRRef<py::object>>(rref_)->getValue(); std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue());
PythonRpcHandler::getInstance().handleException(value); PythonRpcHandler::getInstance().handleException(value);
{ {
// acquiring GIL as the return statement construct a new py::object from // acquiring GIL as the return statement construct a new py::object from
@ -85,8 +71,7 @@ py::object PyRRef::localValue() {
return value; return value;
} }
} else { } else {
auto value = auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
std::dynamic_pointer_cast<OwnerRRef<IValue>>(rref_)->getValue();
{ {
// acquiring GIL as torch::jit::toPyObject creates new py::object without // acquiring GIL as torch::jit::toPyObject creates new py::object without
// grabbing the GIL. // grabbing the GIL.
@ -101,13 +86,9 @@ std::string PyRRef::str() const {
if (rref_->isOwner()) { if (rref_->isOwner()) {
ss << "OwnerRRef(" << rref_->rrefId() << ")"; ss << "OwnerRRef(" << rref_->rrefId() << ")";
} else { } else {
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = "; ss << "UserRRef(RRefId = " << rref_->rrefId()
if (rref_->isPyObj()) { << ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
ss << std::static_pointer_cast<UserRRef<py::object>>(rref_)->forkId(); << ")";
} else {
ss << std::static_pointer_cast<UserRRef<IValue>>(rref_)->forkId();
}
ss << ")";
} }
return ss.str(); return ss.str();
} }
@ -119,21 +100,16 @@ py::tuple PyRRef::pickle() const {
// a counter example, checkpointing a model with RRefs should not trigger // a counter example, checkpointing a model with RRefs should not trigger
// forks to be added as a fork or a child. // forks to be added as a fork or a child.
auto rfd = ctx.prepareChildFork(rref_); auto rfd = ctx.prepareChildFork(rref_);
return py::make_tuple(rfd.toPyTuple(), rref_->isPyObj()); return rfd.toPyTuple();
} }
PyRRef PyRRef::unpickle(const py::tuple& t) { PyRRef PyRRef::unpickle(const py::tuple& t) {
TORCH_INTERNAL_ASSERT(
t.size() == RREF_TUPLE_SIZE, "Pickled RRef must contain 2 numbers.");
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
auto rfd = RRefForkData::fromPyTuple(t[RFD_IDX].cast<py::tuple>()); auto rfd = RRefForkData::fromPyTuple(t.cast<py::tuple>());
std::shared_ptr<RRef> rref = nullptr; std::shared_ptr<RRef> rref = nullptr;
bool isPyObj = t[TYPE_IDX].cast<bool>(); TypePtr rref_type =
if (isPyObj) { PythonRpcHandler::getInstance().parseTypeFromStr(rfd.type_str_);
rref = ctx.getOrCreateRRef<py::object>(rfd); rref = ctx.getOrCreateRRef(rfd, rref_type);
} else {
rref = ctx.getOrCreateRRef<IValue>(rfd);
}
ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref); ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
return PyRRef(std::move(rref)); return PyRRef(std::move(rref));

View File

@ -159,13 +159,14 @@ PyRRef pyRemoteBuiltin(
const py::kwargs& kwargs) { const py::kwargs& kwargs) {
Stack stack; Stack stack;
auto op = matchBuiltinOp(opName, args, kwargs, stack); auto op = matchBuiltinOp(opName, args, kwargs, stack);
TypePtr ret_type = op->schema().returns()[0].type();
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
// TODO: support creating RRefs on a local object. // TODO: support creating RRefs on a local object.
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
ctx.getWorkerId() != dst.id_, ctx.getWorkerId() != dst.id_,
"Does not support creating RRef on self yet."); "Does not support creating RRef on self yet.");
auto userRRef = ctx.createUserRRef<IValue>(dst.id_); auto userRRef = ctx.createUserRRef(dst.id_, ret_type);
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>( auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
op, std::move(stack), userRRef->rrefId(), userRRef->forkId()); op, std::move(stack), userRRef->rrefId(), userRRef->forkId());
@ -205,7 +206,7 @@ PyRRef pyRemotePythonUdf(
auto serializedPyObj = auto serializedPyObj =
SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors)); SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
if (ctx.getWorkerId() != dst.id_) { if (ctx.getWorkerId() != dst.id_) {
auto userRRef = ctx.createUserRRef<py::object>(dst.id_); auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
ctx.addPendingUser(userRRef->forkId(), userRRef); ctx.addPendingUser(userRRef->forkId(), userRRef);
auto fm = sendPythonRemoteCall( auto fm = sendPythonRemoteCall(
agent, agent,
@ -218,7 +219,7 @@ PyRRef pyRemotePythonUdf(
fm->addCallback(finishAcceptUserRRef); fm->addCallback(finishAcceptUserRRef);
return PyRRef(userRRef); return PyRRef(userRRef);
} else { } else {
auto ownerRRef = ctx.createOwnerRRef<py::object>(); auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
// prevent this owner RRef be deleted due to other forks // prevent this owner RRef be deleted due to other forks
ctx.addSelfAsFork(ownerRRef); ctx.addSelfAsFork(ownerRRef);
auto fm = sendPythonRemoteCall( auto fm = sendPythonRemoteCall(

View File

@ -8,6 +8,28 @@ namespace rpc {
namespace { namespace {
// PythonTypeResolver inherits from jit::script::Resolver so that it can be
// plugged into a ScriptTypeParser and resolve type names for RPC.
// It only resolves types; value resolution is intentionally unsupported.
struct PythonTypeResolver : public jit::script::Resolver {
// Value resolution is never expected on the RPC type-recovery path, so any
// call to this method trips an internal assert.
std::shared_ptr<jit::script::SugaredValue> resolveValue(
const std::string& /* unused */,
Function& /* unused */,
const jit::SourceRange& /* unused */) override {
TORCH_INTERNAL_ASSERT(
false, "RPC Type resolver does not need to resolve value");
}
// Maps a type name to a TypePtr:
//   - the literal name "PyObject" resolves to PyObjectType::get();
//   - any other name is looked up in the Python compilation unit.
TypePtr resolveType(
const std::string& name,
const jit::SourceRange& /* unused */) override {
if (name == "PyObject") {
return PyObjectType::get();
}
// Fall back to the types known to the Python compilation unit.
auto python_cu = torch::jit::get_python_cu();
return python_cu->get_type(name);
}
};
py::object getFunction(const py::object& module, const char* name) { py::object getFunction(const py::object& module, const char* name) {
py::object fn = module.attr(name); py::object fn = module.attr(name);
TORCH_CHECK( TORCH_CHECK(
@ -28,6 +50,8 @@ PythonRpcHandler::PythonRpcHandler() {
pySerialize_ = getFunction(module, "serialize"); pySerialize_ = getFunction(module, "serialize");
pyHandleException_ = getFunction(module, "_handle_exception"); pyHandleException_ = getFunction(module, "_handle_exception");
jitCompilationUnit_ = torch::jit::get_python_cu(); jitCompilationUnit_ = torch::jit::get_python_cu();
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
std::make_shared<PythonTypeResolver>());
} }
void PythonRpcHandler::cleanup() { void PythonRpcHandler::cleanup() {
@ -95,6 +119,10 @@ void PythonRpcHandler::handleException(const py::object& obj) {
pyHandleException_(obj); pyHandleException_(obj);
} }
// Parses `type_str` into a TypePtr by delegating to typeParser_, the
// ScriptTypeParser that the constructor builds with a PythonTypeResolver.
TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
return typeParser_->parseType(type_str);
}
} // namespace rpc } // namespace rpc
} // namespace distributed } // namespace distributed
} // namespace torch } // namespace torch

View File

@ -2,6 +2,7 @@
#include <torch/csrc/distributed/rpc/message.h> #include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/types.h> #include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/utils/pybind.h> #include <torch/csrc/utils/pybind.h>
namespace torch { namespace torch {
@ -59,6 +60,17 @@ class PYBIND11_EXPORT PythonRpcHandler {
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit(); std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit();
// Parses a type string back into a JIT type (TypePtr). This is used to
// recover the type of an RRef during Python pickling/unpickling. The type
// string is resolved as follows:
// 1. First, try to parse it as a primitive type,
// e.g. TensorType, IntType, PyObjectType, etc.
// 2. If it is not a primitive type, query the python_cu to see whether it
// is a class type or interface type registered in Python.
// A ScriptTypeParser instance with a custom PythonTypeResolver is used
// to resolve types according to the rules above.
TypePtr parseTypeFromStr(const std::string& typeStr);
private: private:
PythonRpcHandler(); PythonRpcHandler();
~PythonRpcHandler() = default; ~PythonRpcHandler() = default;
@ -102,6 +114,10 @@ class PYBIND11_EXPORT PythonRpcHandler {
// We import the compilation unit here only once for less cost and thread // We import the compilation unit here only once for less cost and thread
// safety. // safety.
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_; std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_;
// jit type parser to parse type_str back to TypePtr for RRef type
// recovery when pickling and unpickling RRef
std::shared_ptr<jit::script::ScriptTypeParser> typeParser_;
}; };
} // namespace rpc } // namespace rpc

View File

@ -21,6 +21,7 @@
#include <torch/csrc/distributed/rpc/script_remote_call.h> #include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h> #include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h> #include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch { namespace torch {
namespace distributed { namespace distributed {
@ -82,7 +83,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
auto& src = static_cast<ScriptRemoteCall&>(rpc); auto& src = static_cast<ScriptRemoteCall&>(rpc);
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx.getOrCreateOwnerRRef<IValue>(src.retRRefId()); TypePtr ret_type = src.op()->schema().returns()[0].type();
auto ownerRRef = ctx.getOrCreateOwnerRRef(src.retRRefId(), ret_type);
// TODO: make this asynchronous // TODO: make this asynchronous
// src is only alive within this block, use reference to avoid copy // src is only alive within this block, use reference to avoid copy
@ -106,10 +108,13 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
auto forkId = ForkId::fromIValue(prc.retForkId()); auto forkId = ForkId::fromIValue(prc.retForkId());
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx.getOrCreateOwnerRRef<py::object>(rrefId); auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get());
ownerRRef->setValue( IValue py_ivalue = jit::toIValue(
PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj())); PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()),
PyObjectType::get());
ownerRRef->setValue(std::move(py_ivalue));
if (rrefId != forkId) { if (rrefId != forkId) {
// Caller is a user and callee is the owner, add fork // Caller is a user and callee is the owner, add fork
@ -127,8 +132,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
case MessageType::SCRIPT_RREF_FETCH_CALL: { case MessageType::SCRIPT_RREF_FETCH_CALL: {
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc); auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
std::shared_ptr<OwnerRRef<IValue>> rref = std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
ctx.getOwnerRRef<IValue>(srf.rrefId());
if (rref->hasValue()) { // optional fast-path if (rref->hasValue()) { // optional fast-path
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage()); return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
} }
@ -149,11 +153,10 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
case MessageType::PYTHON_RREF_FETCH_CALL: { case MessageType::PYTHON_RREF_FETCH_CALL: {
auto& prf = static_cast<PythonRRefFetchCall&>(rpc); auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
std::shared_ptr<OwnerRRef<py::object>> rref = std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(prf.rrefId());
ctx.getOwnerRRef<py::object>(prf.rrefId());
if (rref->hasValue()) { // optional fast-path if (rref->hasValue()) { // optional fast-path
SerializedPyObj result = SerializedPyObj result = PythonRpcHandler::getInstance().serialize(
PythonRpcHandler::getInstance().serialize(rref->getValue()); jit::toPyObject(rref->getValue()));
return wrap(PythonRRefFetchRet(result.toIValues()).toMessage()); return wrap(PythonRRefFetchRet(result.toIValues()).toMessage());
} }
@ -165,8 +168,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
[responseFuture, messageId, rref]( [responseFuture, messageId, rref](
const rpc::Message& /* unused */, const rpc::Message& /* unused */,
const c10::optional<utils::FutureError>& /* unused */) { const c10::optional<utils::FutureError>& /* unused */) {
SerializedPyObj result = SerializedPyObj result = PythonRpcHandler::getInstance().serialize(
PythonRpcHandler::getInstance().serialize(rref->getValue()); jit::toPyObject(rref->getValue()));
Message m = PythonRRefFetchRet(result.toIValues()).toMessage(); Message m = PythonRRefFetchRet(result.toIValues()).toMessage();
m.setId(messageId); m.setId(messageId);
responseFuture->markCompleted(m); responseFuture->markCompleted(m);

View File

@ -83,28 +83,23 @@ void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
} }
} }
template <typename T> std::shared_ptr<UserRRef> RRefContext::createUserRRef(
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(worker_id_t ownerId) { worker_id_t ownerId,
const TypePtr& type) {
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner."); TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
// Explicitly creating rrefId before forkId to make sure the order is // Explicitly creating rrefId before forkId to make sure the order is
// deterministic, as the argument evaluation order is system and compiler // deterministic, as the argument evaluation order is system and compiler
// dependent. // dependent.
const auto rrefId = genGloballyUniqueId(); const auto rrefId = genGloballyUniqueId();
const auto forkId = genGloballyUniqueId(); const auto forkId = genGloballyUniqueId();
return createUserRRef<T>(ownerId, rrefId, forkId); return createUserRRef(ownerId, rrefId, forkId, type);
} }
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>( std::shared_ptr<UserRRef> RRefContext::createUserRRef(
worker_id_t ownerId);
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
py::object>(worker_id_t ownerId);
template <typename T>
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
worker_id_t ownerId, worker_id_t ownerId,
const RRefId& rrefId, const RRefId& rrefId,
const ForkId& forkId) { const ForkId& forkId,
const TypePtr& type) {
TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef."); TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
// RRefContext does not track user RRefs, it will be destructed when there // RRefContext does not track user RRefs, it will be destructed when there
// is no shared_ptrs pointing to it. // is no shared_ptrs pointing to it.
@ -119,20 +114,9 @@ std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
// The reason for not adding the pending user here is to put addPendingUser() // The reason for not adding the pending user here is to put addPendingUser()
// close to where the RPC occurs, and it is more clear to pair it with // close to where the RPC occurs, and it is more clear to pair it with
// deletePendingUser() in the response callback at the call site. // deletePendingUser() in the response callback at the call site.
return std::shared_ptr<UserRRef<T>>(new UserRRef<T>(ownerId, rrefId, forkId)); return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId, type));
} }
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId);
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
py::object>(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId);
void RRefContext::delUser( void RRefContext::delUser(
const worker_id_t owner, const worker_id_t owner,
const RRefId& rrefId, const RRefId& rrefId,
@ -150,27 +134,24 @@ void RRefContext::delUser(
} }
} }
template <typename T> std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(const RRefForkData& rfd) { const RRefForkData& rfd,
const TypePtr& type) {
auto& ownerId = rfd.ownerId_; auto& ownerId = rfd.ownerId_;
auto& rrefId = rfd.rrefId_; auto& rrefId = rfd.rrefId_;
auto& forkId = rfd.forkId_; auto& forkId = rfd.forkId_;
if (ownerId == getWorkerId()) { if (ownerId == getWorkerId()) {
return getOwnerRRef<T>(rrefId); auto ownerRRef = getOwnerRRef(rrefId);
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
return ownerRRef;
} else { } else {
return createUserRRef<T>(ownerId, rrefId, forkId); return createUserRRef(ownerId, rrefId, forkId, type);
} }
} }
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<IValue>( std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
const RRefForkData& rfd); const RRefId& rrefId,
const TypePtr& type) {
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<py::object>(
const RRefForkData& rfd);
template <typename T>
std::shared_ptr<OwnerRRef<T>> RRefContext::getOrCreateOwnerRRef(
const RRefId& rrefId) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
const auto iter = owners_.find(rrefId); const auto iter = owners_.find(rrefId);
if (iter == owners_.end()) { if (iter == owners_.end()) {
@ -179,58 +160,40 @@ std::shared_ptr<OwnerRRef<T>> RRefContext::getOrCreateOwnerRRef(
// NB: cannot use make_shared here as the constructor of OwnerRRef is // NB: cannot use make_shared here as the constructor of OwnerRRef is
// private. // private.
auto rref = auto rref =
std::shared_ptr<OwnerRRef<T>>(new OwnerRRef<T>(getWorkerId(), rrefId)); std::shared_ptr<OwnerRRef>(new OwnerRRef(getWorkerId(), rrefId, type));
owners_[rref->rrefId()] = rref; owners_[rref->rrefId()] = rref;
ownerCV_.notify_all(); ownerCV_.notify_all();
return rref; return rref;
} else { } else {
// Scenario (2) retrieving an existing RRef // Scenario (2) retrieving an existing RRef
return std::static_pointer_cast<OwnerRRef<T>>(iter->second); auto ownerRRef = std::static_pointer_cast<OwnerRRef>(iter->second);
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
return ownerRRef;
} }
} }
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::getOrCreateOwnerRRef< std::shared_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
IValue>(const RRefId& rrefId);
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::
getOrCreateOwnerRRef<py::object>(const RRefId& rrefId);
template <typename T>
std::shared_ptr<OwnerRRef<T>> RRefContext::createOwnerRRef() {
// Don't add this OnwerRRef to the owners_ map yet, otherwise // Don't add this OnwerRRef to the owners_ map yet, otherwise
// it will never be removed from there. Instead, only add it to the // it will never be removed from there. Instead, only add it to the
// map in prepareChildFork, in case this local RRef is being passed // map in prepareChildFork, in case this local RRef is being passed
// to another worker. // to another worker.
return std::shared_ptr<OwnerRRef<T>>( return std::shared_ptr<OwnerRRef>(
new OwnerRRef<T>(getWorkerId(), genGloballyUniqueId())); new OwnerRRef(getWorkerId(), genGloballyUniqueId(), type));
} }
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::createOwnerRRef< std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
IValue>();
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::createOwnerRRef<
py::object>();
template <typename T>
std::shared_ptr<OwnerRRef<T>> RRefContext::getOwnerRRef(const RRefId& rrefId) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
const auto iter = owners_.find(rrefId); const auto iter = owners_.find(rrefId);
if (iter == owners_.end()) { if (iter == owners_.end()) {
// Scenario (1) RRef is used before it is created // Scenario (1) RRef is used before it is created
ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); }); ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); });
return std::static_pointer_cast<OwnerRRef<T>>(owners_[rrefId]); return std::static_pointer_cast<OwnerRRef>(owners_[rrefId]);
} else { } else {
// Scenario (2) retrieving an existing RRef // Scenario (2) retrieving an existing RRef
return std::static_pointer_cast<OwnerRRef<T>>(iter->second); return std::static_pointer_cast<OwnerRRef>(iter->second);
} }
} }
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::getOwnerRRef<IValue>(
const RRefId& rrefId);
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::getOwnerRRef<
py::object>(const RRefId& rrefId);
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) { RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
auto rfd = rref->fork(); auto rfd = rref->fork();
if (rref->isOwner()) { if (rref->isOwner()) {
@ -367,8 +330,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
}); });
} }
template <typename T> void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef>& rref) {
void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
const auto& rrefId = rref->rrefId(); const auto& rrefId = rref->rrefId();
owners_[rrefId] = rref; owners_[rrefId] = rref;
@ -380,12 +342,6 @@ void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref) {
rrefForks.insert(rrefId); rrefForks.insert(rrefId);
} }
template void RRefContext::addSelfAsFork<IValue>(
std::shared_ptr<OwnerRRef<IValue>>& rref);
template void RRefContext::addSelfAsFork<py::object>(
std::shared_ptr<OwnerRRef<py::object>>& rref);
void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) { void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto& rrefForks = forks_[rrefId]; auto& rrefForks = forks_[rrefId];

View File

@ -47,26 +47,27 @@ class RRefContext {
} }
// create a ``UserRRef`` owned by the worker ``ownerId`` // create a ``UserRRef`` owned by the worker ``ownerId``
template <typename T> std::shared_ptr<UserRRef> createUserRRef(
std::shared_ptr<UserRRef<T>> createUserRRef(worker_id_t ownerId); worker_id_t ownerId,
const TypePtr& type);
// Convert an RRefForkData into an RRef. This RRef could be user or owner. // Convert an RRefForkData into an RRef. This RRef could be user or owner.
// This RRef could have already existed before, or could be created in this // This RRef could have already existed before, or could be created in this
// method. // method, we pass type here to validate or help the rref creation.
template <typename T> std::shared_ptr<RRef> getOrCreateRRef(
std::shared_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd); const RRefForkData& rfd,
const TypePtr& type);
// Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
// one. // one.
template <typename T> std::shared_ptr<OwnerRRef> getOrCreateOwnerRRef(
std::shared_ptr<OwnerRRef<T>> getOrCreateOwnerRRef(const RRefId& rrefId); const RRefId& rrefId,
const TypePtr& type);
// Create an empty owner rref of type T. // Create an empty owner rref of type.
template <typename T> std::shared_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
std::shared_ptr<OwnerRRef<T>> createOwnerRRef();
template <typename T> std::shared_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
std::shared_ptr<OwnerRRef<T>> getOwnerRRef(const RRefId& rrefId);
// Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
// making a remote call to self, which as for now, still goes through serde // making a remote call to self, which as for now, still goes through serde
@ -78,8 +79,7 @@ class RRefContext {
// and this could happen before the self remote call finishes. To prevent // and this could happen before the self remote call finishes. To prevent
// that, this API adds the RRefId as a ForkId, which will then delete the // that, this API adds the RRefId as a ForkId, which will then delete the
// ForkId when the self remote is done. // ForkId when the self remote is done.
template <typename T> void addSelfAsFork(std::shared_ptr<OwnerRRef>& rref);
void addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref);
// Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the // Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the
// ``OwnerRRef`` in a map to keep it alive. // ``OwnerRRef`` in a map to keep it alive.
@ -124,11 +124,11 @@ class RRefContext {
private: private:
RRefContext(std::shared_ptr<RpcAgent>); RRefContext(std::shared_ptr<RpcAgent>);
template <typename T> std::shared_ptr<UserRRef> createUserRRef(
std::shared_ptr<UserRRef<T>> createUserRRef(
worker_id_t ownerId, worker_id_t ownerId,
const RRefId& rrefId, const RRefId& rrefId,
const ForkId& forkId); const ForkId& forkId,
const TypePtr& type);
void finishForkRequest(const ForkId& forkId, worker_id_t parent); void finishForkRequest(const ForkId& forkId, worker_id_t parent);

View File

@ -6,6 +6,7 @@
#include <torch/csrc/distributed/rpc/rref_context.h> #include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h> #include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h> #include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch { namespace torch {
namespace distributed { namespace distributed {
@ -19,9 +20,10 @@ constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of type_str in the tuple
// NB: if more fields are added, make sure this field is also bumped // NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
} // namespace } // namespace
std::atomic<local_id_t> RRefContext::nextLocalId_{0}; std::atomic<local_id_t> RRefContext::nextLocalId_{0};
@ -32,8 +34,13 @@ RRefForkData::RRefForkData(
worker_id_t ownerId, worker_id_t ownerId,
const RRefId& rrefId, const RRefId& rrefId,
const ForkId& forkId, const ForkId& forkId,
worker_id_t parent) worker_id_t parent,
: ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {} std::string type_str)
: ownerId_(ownerId),
rrefId_(rrefId),
forkId_(forkId),
parent_(parent),
type_str_(std::move(type_str)) {}
py::tuple RRefForkData::toPyTuple() const { py::tuple RRefForkData::toPyTuple() const {
return py::make_tuple( return py::make_tuple(
@ -42,7 +49,8 @@ py::tuple RRefForkData::toPyTuple() const {
rrefId_.localId_, rrefId_.localId_,
forkId_.createdOn_, forkId_.createdOn_,
forkId_.localId_, forkId_.localId_,
parent_); parent_,
type_str_);
} }
RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) { RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
@ -57,29 +65,39 @@ RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
const RRefId& forkId = RRefId( const RRefId& forkId = RRefId(
t[FORKID_ON_IDX].cast<worker_id_t>(), t[FORKID_ON_IDX].cast<worker_id_t>(),
t[FORKID_ID_IDX].cast<local_id_t>()); t[FORKID_ID_IDX].cast<local_id_t>());
worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>(); worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
return RRefForkData(ownerId, rrefId, forkId, parent); const std::string& typeStr = t[TYPE_IDX].cast<std::string>();
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
} }
////////////////////////////// RRef ///////////////////////////////////// ////////////////////////////// RRef /////////////////////////////////////
RRef::RRef(worker_id_t ownerId, const RRefId& rrefId) RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
: RRefInterface(), ownerId_(ownerId), rrefId_(rrefId) {} : RRefInterface(),
ownerId_(ownerId),
rrefId_(rrefId),
type_(std::move(type)) {}
RRefForkData RRef::fork() const { RRefForkData RRef::fork() const {
auto& ctx = RRefContext::getInstance(); auto& ctx = RRefContext::getInstance();
return RRefForkData( return RRefForkData(
ownerId_, rrefId_, ctx.genGloballyUniqueId(), ctx.getWorkerId()); ownerId_,
rrefId_,
ctx.genGloballyUniqueId(),
ctx.getWorkerId(),
type_->str());
} }
////////////////////////// UserRRef ///////////////////////////////////// ////////////////////////// UserRRef /////////////////////////////////////
template <typename T> UserRRef::UserRRef(
UserRRef<T>::UserRRef(
worker_id_t ownerId, worker_id_t ownerId,
const RRefId& rrefId, const RRefId& rrefId,
const ForkId& forkId) const ForkId& forkId,
: RRef(ownerId, rrefId), forkId_(forkId) { TypePtr type)
: RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
// Do nothing, // Do nothing,
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
// a RREF_FORK_REQUEST message to the owner. // a RREF_FORK_REQUEST message to the owner.
@ -87,8 +105,7 @@ UserRRef<T>::UserRRef(
// properly notify the owner. // properly notify the owner.
} }
template <typename T> UserRRef::~UserRRef() {
UserRRef<T>::~UserRRef() {
try { try {
RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_); RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
@ -102,80 +119,65 @@ UserRRef<T>::~UserRRef() {
} }
} }
template <typename T> const ForkId& UserRRef::forkId() const {
const ForkId& UserRRef<T>::forkId() const {
return forkId_; return forkId_;
} }
template <> IValue UserRRef::toHere() {
IValue UserRRef<IValue>::toHere() {
auto agent = RpcAgent::getDefaultRpcAgent(); auto agent = RpcAgent::getDefaultRpcAgent();
// ScriptRRefFetchCall message always carries autograd context id even if // ScriptRRefFetchCall message always carries autograd context id even if
// the message itself does not contain any tensor, because the response would // the message itself does not contain any tensor, because the response would
// potentially contain tensors. // potentially contain tensors.
Message msgToSend;
if (isPyObj()) {
msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
} else {
msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
}
auto futureResponse = autograd::sendMessageWithAutograd( auto futureResponse = autograd::sendMessageWithAutograd(
*agent, *agent,
agent->getWorkerInfo(ownerId_), agent->getWorkerInfo(ownerId_),
ScriptRRefFetchCall(ownerId_, rrefId()).toMessage(), std::move(msgToSend),
true /* forceGradRecording */); true /* forceGradRecording */);
const Message& message = futureResponse->wait(); const Message& message = futureResponse->wait();
MessageType msgType = message.type(); MessageType msgType = message.type();
auto response = deserializeResponse(message, msgType); auto response = deserializeResponse(message, msgType);
TORCH_INTERNAL_ASSERT( TORCH_INTERNAL_ASSERT(
msgType == MessageType::SCRIPT_RREF_FETCH_RET, msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
"Message type should be SCRIPT_RREF_FETCH_RET."); msgType == MessageType::PYTHON_RREF_FETCH_RET,
"Message type should either be SCRIPT_RREF_FETCH_RET "
"or PYTHON_RREF_FETCH_RET");
RpcCommandBase& rpc = *response; RpcCommandBase& rpc = *response;
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc); if (isPyObj()) {
return rfr.values().front(); auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
return jit::toIValue(
PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr.values())),
PyObjectType::get());
} else {
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
return rfr.values().front();
}
} }
template <>
py::object UserRRef<py::object>::toHere() {
auto agent = RpcAgent::getDefaultRpcAgent();
// PythonRRefFetchCall message always carries autograd context id even if
// the message itself does not contain any tensor, because the response would
// potentially contain tensors.
auto futureResponse = autograd::sendMessageWithAutograd(
*agent,
agent->getWorkerInfo(ownerId_),
PythonRRefFetchCall(ownerId_, rrefId()).toMessage(),
true /* forceGradRecording */);
const Message& message = futureResponse->wait();
MessageType msgType = message.type();
auto response = deserializeResponse(message, msgType);
TORCH_INTERNAL_ASSERT(
msgType == MessageType::PYTHON_RREF_FETCH_RET,
"Message type should be PYTHON_RREF_FETCH_RET.");
RpcCommandBase& rpc = *response;
auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr.values()));
}
template class UserRRef<IValue>;
template class UserRRef<py::object>;
////////////////////////// OwnerRRef ///////////////////////////////////// ////////////////////////// OwnerRRef /////////////////////////////////////
template <typename T> const IValue& OwnerRRef::getValue() const {
const T& OwnerRRef<T>::getValue() const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
valueCV_.wait(lock, [this] { return value_.has_value(); }); valueCV_.wait(lock, [this] { return value_.has_value(); });
return value_.value(); return value_.value();
} }
template <typename T> bool OwnerRRef::hasValue() const {
bool OwnerRRef<T>::hasValue() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return value_.has_value(); return value_.has_value();
} }
template <typename T> std::shared_ptr<FutureMessage> OwnerRRef::getFuture() {
std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (future_.get()) { if (future_.get()) {
return future_; return future_;
@ -189,8 +191,7 @@ std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
return ret; return ret;
} }
template <typename T> void OwnerRRef::setValue(IValue&& value) {
void OwnerRRef<T>::setValue(T&& value) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
value_ = std::move(value); value_ = std::move(value);
std::shared_ptr<FutureMessage> future; std::shared_ptr<FutureMessage> future;
@ -202,9 +203,6 @@ void OwnerRRef<T>::setValue(T&& value) {
} }
} }
template class OwnerRRef<IValue>;
template class OwnerRRef<py::object>;
} // namespace rpc } // namespace rpc
} // namespace distributed } // namespace distributed
} // namespace torch } // namespace torch

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <ATen/core/jit_type.h>
#include <c10/util/Optional.h> #include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/message.h> #include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h> #include <torch/csrc/distributed/rpc/rpc_agent.h>
@ -15,7 +16,6 @@ namespace rpc {
class RRef; class RRef;
class RRefContext; class RRefContext;
template <typename T>
class UserRRef; class UserRRef;
// Represents fork of an RRef to be sent over the wire. // Represents fork of an RRef to be sent over the wire.
@ -27,24 +27,21 @@ struct RRefForkData {
const RRefId rrefId_; const RRefId rrefId_;
const ForkId forkId_; const ForkId forkId_;
const worker_id_t parent_; const worker_id_t parent_;
const std::string type_str_;
private: private:
friend class RRef; friend class RRef;
friend class RRefContext; friend class RRefContext;
template <typename T>
friend class UserRRef; friend class UserRRef;
RRefForkData( RRefForkData(
worker_id_t ownerId, worker_id_t ownerId,
const RRefId& rrefId_, const RRefId& rrefId_,
const ForkId& forkId_, const ForkId& forkId_,
worker_id_t parent); worker_id_t parent,
std::string type_str);
}; };
static_assert(
C10_IS_TRIVIALLY_COPYABLE(RRefForkData),
"RRefForkData must be trivially copyable");
// Note [RRef Protocol] // Note [RRef Protocol]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~
// //
@ -207,25 +204,32 @@ class RRef : public RRefInterface {
return rrefId_; return rrefId_;
} }
// returns true if this RRef holds an py::object, false if IValue inline bool isPyObj() {
virtual bool isPyObj() = 0; return type_ == PyObjectType::get();
}
inline const TypePtr type() {
return type_;
}
protected: protected:
friend class RRefContext; friend class RRefContext;
RRef(worker_id_t ownerId, const RRefId& rrefId); RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
RRefForkData fork() const; RRefForkData fork() const;
const worker_id_t ownerId_; const worker_id_t ownerId_;
const RRefId rrefId_; const RRefId rrefId_;
// type field to denote the type of the element that the RRef is holding
// it could be any TypePtr that JIT supports, including PyObjectType
const TypePtr type_;
}; };
// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
// also has a globally unique ``ForkId`` to identify this user. ``UserRRef`` // also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
// never owns the real value, the only way to get the value of the ``RRef`` is // never owns the real value, the only way to get the value of the ``RRef`` is
// to call ``to_here()`` and get a copy. // to call ``to_here()`` and get a copy.
template <typename T>
class UserRRef final : public RRef { class UserRRef final : public RRef {
public: public:
UserRRef(const UserRRef& other) = delete; UserRRef(const UserRRef& other) = delete;
@ -237,16 +241,12 @@ class UserRRef final : public RRef {
return false; return false;
} }
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
// Returns the globally unique ForkId of this RRef // Returns the globally unique ForkId of this RRef
const ForkId& forkId() const; const ForkId& forkId() const;
// Get a copy of the value from the ``OwnerRRef``. If the value is not ready // Get a copy of the value from the ``OwnerRRef``. If the value is not ready
// yet, this call will block. // yet, this call will block.
T toHere(); IValue toHere();
// Upon destruction, this ``UserRRef`` will tell the owner to deref. // Upon destruction, this ``UserRRef`` will tell the owner to deref.
~UserRRef() override; ~UserRRef() override;
@ -254,14 +254,17 @@ class UserRRef final : public RRef {
private: private:
friend class RRefContext; friend class RRefContext;
UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId); UserRRef(
worker_id_t ownerId,
const RRefId& rrefId,
const ForkId& forkId,
TypePtr type);
const ForkId forkId_; const ForkId forkId_;
}; };
// Keep the template only on the derived class because ``RRefContext`` needs to // Keep the template only on the derived class because ``RRefContext`` needs to
// erase the type on ``RRef`` and keep them in one map. // erase the type on ``RRef`` and keep them in one map.
template <typename T>
class OwnerRRef final : public RRef { class OwnerRRef final : public RRef {
public: public:
OwnerRRef(const OwnerRRef& other) = delete; OwnerRRef(const OwnerRRef& other) = delete;
@ -273,18 +276,14 @@ class OwnerRRef final : public RRef {
return true; return true;
} }
inline bool isPyObj() override {
return std::is_same<T, py::object>::value;
}
// Get a constant reference of the real value. This method will block if the // Get a constant reference of the real value. This method will block if the
// value is not ready. This method does not need GIL as it does not create // value is not ready. This method does not need GIL as it does not create
// any new py::object. // any new py::object.
const T& getValue() const; const IValue& getValue() const;
// Set the value of this ``OwnerRRef``. This method does not need GIL as it // Set the value of this ``OwnerRRef``. This method does not need GIL as it
// does not create any new py::object. // does not create any new py::object.
void setValue(T&& value); void setValue(IValue&& value);
// Has a value been set? // Has a value been set?
bool hasValue() const; bool hasValue() const;
@ -294,15 +293,19 @@ class OwnerRRef final : public RRef {
private: private:
friend class RRefContext; friend class RRefContext;
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId) OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
: OwnerRRef(ownerId, rrefId, {}) {} : OwnerRRef(ownerId, rrefId, type, {}) {}
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional<T> value) OwnerRRef(
: RRef(ownerId, rrefId) { worker_id_t ownerId,
const RRefId& rrefId,
TypePtr type,
c10::optional<IValue> value)
: RRef(ownerId, rrefId, std::move(type)) {
value_ = std::move(value); value_ = std::move(value);
} }
c10::optional<T> value_; c10::optional<IValue> value_;
mutable std::mutex mutex_; mutable std::mutex mutex_;
mutable std::condition_variable valueCV_; mutable std::condition_variable valueCV_;
std::shared_ptr<FutureMessage> future_; std::shared_ptr<FutureMessage> future_;