mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[rpc] Remove template on RRef and add Type to RRef creation (#30630)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30630 This remove template and all the specializations it have in rpc, we universally use IValue as the inner value since we support making python object to be hold inside IValue. This will also ensure that we have the correct type information when creating the RRef, we use the return type from the schema when creating userRRef and OwnerRRef, it will enable IValue to always have the correct type if the IValue is the RRef object (next PR) Test Plan: Imported from OSS Differential Revision: D19502235 fbshipit-source-id: 0d5decae8a9767e0893f3b8b6456b231653be3c5
This commit is contained in:
committed by
Facebook Github Bot
parent
ef2d4e67d1
commit
b474c351dd
@ -8,21 +8,6 @@
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
namespace {
|
||||
|
||||
// Constants below are used in PyRRef pickling and unpickling. PyRRef is
|
||||
// converted into a py::tuple in pickling, and reconstructed from the py::tuple
|
||||
// in unpickling.
|
||||
constexpr int RFD_IDX = 0; // index of RRefForkData
|
||||
constexpr int TYPE_IDX = 1; // index of type (py::object or IValue)
|
||||
|
||||
// number of data fields in the py::tuple.
|
||||
// NB: if more fields are added, make sure this field is also bumped
|
||||
constexpr int RREF_TUPLE_SIZE = 2;
|
||||
|
||||
} // namespace
|
||||
|
||||
/////////////////////////// PyRRef //////////////////////////////////
|
||||
|
||||
PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
@ -31,9 +16,11 @@ PyRRef::PyRRef(std::shared_ptr<RRef> rref) : rref_(std::move(rref)) {
|
||||
|
||||
PyRRef::PyRRef(const py::object& value)
|
||||
: PyRRef([&value]() {
|
||||
auto rref = RRefContext::getInstance().createOwnerRRef<py::object>();
|
||||
auto rref =
|
||||
RRefContext::getInstance().createOwnerRRef(PyObjectType::get());
|
||||
py::object copy(value); // increases refcount
|
||||
rref->setValue(std::move(copy));
|
||||
IValue py_ivalue = jit::toIValue(std::move(copy), PyObjectType::get());
|
||||
rref->setValue(std::move(py_ivalue));
|
||||
return rref;
|
||||
}()) {}
|
||||
|
||||
@ -52,10 +39,10 @@ py::object PyRRef::toHere() {
|
||||
if (rref_->isPyObj()) {
|
||||
// UserRRef<py::object>::toHere() calls python_rpc_handler which acquires
|
||||
// GIL.
|
||||
return std::static_pointer_cast<UserRRef<py::object>>(rref_)->toHere();
|
||||
return jit::toPyObject(
|
||||
std::static_pointer_cast<UserRRef>(rref_)->toHere());
|
||||
} else {
|
||||
IValue value =
|
||||
std::static_pointer_cast<UserRRef<IValue>>(rref_)->toHere();
|
||||
IValue value = std::static_pointer_cast<UserRRef>(rref_)->toHere();
|
||||
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object
|
||||
@ -74,9 +61,8 @@ py::object PyRRef::localValue() {
|
||||
owner().name_);
|
||||
|
||||
if (rref_->isPyObj()) {
|
||||
const py::object& value =
|
||||
std::dynamic_pointer_cast<OwnerRRef<py::object>>(rref_)->getValue();
|
||||
|
||||
const py::object& value = jit::toPyObject(
|
||||
std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue());
|
||||
PythonRpcHandler::getInstance().handleException(value);
|
||||
{
|
||||
// acquiring GIL as the return statement construct a new py::object from
|
||||
@ -85,8 +71,7 @@ py::object PyRRef::localValue() {
|
||||
return value;
|
||||
}
|
||||
} else {
|
||||
auto value =
|
||||
std::dynamic_pointer_cast<OwnerRRef<IValue>>(rref_)->getValue();
|
||||
auto value = std::dynamic_pointer_cast<OwnerRRef>(rref_)->getValue();
|
||||
{
|
||||
// acquiring GIL as torch::jit::toPyObject creates new py::object without
|
||||
// grabbing the GIL.
|
||||
@ -101,13 +86,9 @@ std::string PyRRef::str() const {
|
||||
if (rref_->isOwner()) {
|
||||
ss << "OwnerRRef(" << rref_->rrefId() << ")";
|
||||
} else {
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId() << ", ForkId = ";
|
||||
if (rref_->isPyObj()) {
|
||||
ss << std::static_pointer_cast<UserRRef<py::object>>(rref_)->forkId();
|
||||
} else {
|
||||
ss << std::static_pointer_cast<UserRRef<IValue>>(rref_)->forkId();
|
||||
}
|
||||
ss << ")";
|
||||
ss << "UserRRef(RRefId = " << rref_->rrefId()
|
||||
<< ", ForkId = " << std::static_pointer_cast<UserRRef>(rref_)->forkId()
|
||||
<< ")";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
@ -119,21 +100,16 @@ py::tuple PyRRef::pickle() const {
|
||||
// a counter example, checkpointing a model with RRefs should not trigger
|
||||
// forks to be added as a fork or a child.
|
||||
auto rfd = ctx.prepareChildFork(rref_);
|
||||
return py::make_tuple(rfd.toPyTuple(), rref_->isPyObj());
|
||||
return rfd.toPyTuple();
|
||||
}
|
||||
|
||||
PyRRef PyRRef::unpickle(const py::tuple& t) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
t.size() == RREF_TUPLE_SIZE, "Pickled RRef must contain 2 numbers.");
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
auto rfd = RRefForkData::fromPyTuple(t[RFD_IDX].cast<py::tuple>());
|
||||
auto rfd = RRefForkData::fromPyTuple(t.cast<py::tuple>());
|
||||
std::shared_ptr<RRef> rref = nullptr;
|
||||
bool isPyObj = t[TYPE_IDX].cast<bool>();
|
||||
if (isPyObj) {
|
||||
rref = ctx.getOrCreateRRef<py::object>(rfd);
|
||||
} else {
|
||||
rref = ctx.getOrCreateRRef<IValue>(rfd);
|
||||
}
|
||||
TypePtr rref_type =
|
||||
PythonRpcHandler::getInstance().parseTypeFromStr(rfd.type_str_);
|
||||
rref = ctx.getOrCreateRRef(rfd, rref_type);
|
||||
|
||||
ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
|
||||
return PyRRef(std::move(rref));
|
||||
|
@ -159,13 +159,14 @@ PyRRef pyRemoteBuiltin(
|
||||
const py::kwargs& kwargs) {
|
||||
Stack stack;
|
||||
auto op = matchBuiltinOp(opName, args, kwargs, stack);
|
||||
TypePtr ret_type = op->schema().returns()[0].type();
|
||||
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
// TODO: support creating RRefs on a local object.
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
ctx.getWorkerId() != dst.id_,
|
||||
"Does not support creating RRef on self yet.");
|
||||
auto userRRef = ctx.createUserRRef<IValue>(dst.id_);
|
||||
auto userRRef = ctx.createUserRRef(dst.id_, ret_type);
|
||||
|
||||
auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
|
||||
op, std::move(stack), userRRef->rrefId(), userRRef->forkId());
|
||||
@ -205,7 +206,7 @@ PyRRef pyRemotePythonUdf(
|
||||
auto serializedPyObj =
|
||||
SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
|
||||
if (ctx.getWorkerId() != dst.id_) {
|
||||
auto userRRef = ctx.createUserRRef<py::object>(dst.id_);
|
||||
auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
|
||||
ctx.addPendingUser(userRRef->forkId(), userRRef);
|
||||
auto fm = sendPythonRemoteCall(
|
||||
agent,
|
||||
@ -218,7 +219,7 @@ PyRRef pyRemotePythonUdf(
|
||||
fm->addCallback(finishAcceptUserRRef);
|
||||
return PyRRef(userRRef);
|
||||
} else {
|
||||
auto ownerRRef = ctx.createOwnerRRef<py::object>();
|
||||
auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
|
||||
// prevent this owner RRef be deleted due to other forks
|
||||
ctx.addSelfAsFork(ownerRRef);
|
||||
auto fm = sendPythonRemoteCall(
|
||||
|
@ -8,6 +8,28 @@ namespace rpc {
|
||||
|
||||
namespace {
|
||||
|
||||
// PythonTypeResolver that inherits from Script::Resolver to
|
||||
// support resolving types together with ScriptTypeParser.
|
||||
struct PythonTypeResolver : public jit::script::Resolver {
|
||||
std::shared_ptr<jit::script::SugaredValue> resolveValue(
|
||||
const std::string& /* unused */,
|
||||
Function& /* unused */,
|
||||
const jit::SourceRange& /* unused */) override {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "RPC Type resolver does not need to resolve value");
|
||||
}
|
||||
|
||||
TypePtr resolveType(
|
||||
const std::string& name,
|
||||
const jit::SourceRange& /* unused */) override {
|
||||
if (name == "PyObject") {
|
||||
return PyObjectType::get();
|
||||
}
|
||||
auto python_cu = torch::jit::get_python_cu();
|
||||
return python_cu->get_type(name);
|
||||
}
|
||||
};
|
||||
|
||||
py::object getFunction(const py::object& module, const char* name) {
|
||||
py::object fn = module.attr(name);
|
||||
TORCH_CHECK(
|
||||
@ -28,6 +50,8 @@ PythonRpcHandler::PythonRpcHandler() {
|
||||
pySerialize_ = getFunction(module, "serialize");
|
||||
pyHandleException_ = getFunction(module, "_handle_exception");
|
||||
jitCompilationUnit_ = torch::jit::get_python_cu();
|
||||
typeParser_ = std::make_shared<jit::script::ScriptTypeParser>(
|
||||
std::make_shared<PythonTypeResolver>());
|
||||
}
|
||||
|
||||
void PythonRpcHandler::cleanup() {
|
||||
@ -95,6 +119,10 @@ void PythonRpcHandler::handleException(const py::object& obj) {
|
||||
pyHandleException_(obj);
|
||||
}
|
||||
|
||||
TypePtr PythonRpcHandler::parseTypeFromStr(const std::string& type_str) {
|
||||
return typeParser_->parseType(type_str);
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/script/script_type_parser.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
@ -59,6 +60,17 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
||||
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit();
|
||||
|
||||
// Parse the string to recover the jit_type, this is used for RRef python
|
||||
// pickling/unpickling type recovery. The type string inference rule is as
|
||||
// follows:
|
||||
// 1. first try to parse if this is primitive types.
|
||||
// i.e. TensorType, IntType, PyObjectType, etc.
|
||||
// 2. if not primitive type, we query the python_cu to see if it is a
|
||||
// class type or interface type registered in python
|
||||
// We use a ScriptTypeParser instance with custom PythonTypeResolver
|
||||
// to resolve types according to the above rules.
|
||||
TypePtr parseTypeFromStr(const std::string& typeStr);
|
||||
|
||||
private:
|
||||
PythonRpcHandler();
|
||||
~PythonRpcHandler() = default;
|
||||
@ -102,6 +114,10 @@ class PYBIND11_EXPORT PythonRpcHandler {
|
||||
// We import the compilation unit here only once for less cost and thread
|
||||
// safety.
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_;
|
||||
|
||||
// jit type parser to parse type_str back to TypePtr for RRef type
|
||||
// recovery when pickling and unpickling RRef
|
||||
std::shared_ptr<jit::script::ScriptTypeParser> typeParser_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
||||
#include <torch/csrc/distributed/rpc/script_resp.h>
|
||||
#include <torch/csrc/distributed/rpc/utils.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
@ -82,7 +83,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
auto& src = static_cast<ScriptRemoteCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
|
||||
auto ownerRRef = ctx.getOrCreateOwnerRRef<IValue>(src.retRRefId());
|
||||
TypePtr ret_type = src.op()->schema().returns()[0].type();
|
||||
auto ownerRRef = ctx.getOrCreateOwnerRRef(src.retRRefId(), ret_type);
|
||||
|
||||
// TODO: make this asynchronous
|
||||
// src is only alive within this block, use reference to avoid copy
|
||||
@ -106,10 +108,13 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
auto forkId = ForkId::fromIValue(prc.retForkId());
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
|
||||
auto ownerRRef = ctx.getOrCreateOwnerRRef<py::object>(rrefId);
|
||||
auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get());
|
||||
|
||||
ownerRRef->setValue(
|
||||
PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()));
|
||||
IValue py_ivalue = jit::toIValue(
|
||||
PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()),
|
||||
PyObjectType::get());
|
||||
|
||||
ownerRRef->setValue(std::move(py_ivalue));
|
||||
|
||||
if (rrefId != forkId) {
|
||||
// Caller is a user and callee is the owner, add fork
|
||||
@ -127,8 +132,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
case MessageType::SCRIPT_RREF_FETCH_CALL: {
|
||||
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
std::shared_ptr<OwnerRRef<IValue>> rref =
|
||||
ctx.getOwnerRRef<IValue>(srf.rrefId());
|
||||
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
|
||||
if (rref->hasValue()) { // optional fast-path
|
||||
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
|
||||
}
|
||||
@ -149,11 +153,10 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
case MessageType::PYTHON_RREF_FETCH_CALL: {
|
||||
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
std::shared_ptr<OwnerRRef<py::object>> rref =
|
||||
ctx.getOwnerRRef<py::object>(prf.rrefId());
|
||||
std::shared_ptr<OwnerRRef> rref = ctx.getOwnerRRef(prf.rrefId());
|
||||
if (rref->hasValue()) { // optional fast-path
|
||||
SerializedPyObj result =
|
||||
PythonRpcHandler::getInstance().serialize(rref->getValue());
|
||||
SerializedPyObj result = PythonRpcHandler::getInstance().serialize(
|
||||
jit::toPyObject(rref->getValue()));
|
||||
return wrap(PythonRRefFetchRet(result.toIValues()).toMessage());
|
||||
}
|
||||
|
||||
@ -165,8 +168,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
[responseFuture, messageId, rref](
|
||||
const rpc::Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& /* unused */) {
|
||||
SerializedPyObj result =
|
||||
PythonRpcHandler::getInstance().serialize(rref->getValue());
|
||||
SerializedPyObj result = PythonRpcHandler::getInstance().serialize(
|
||||
jit::toPyObject(rref->getValue()));
|
||||
Message m = PythonRRefFetchRet(result.toIValues()).toMessage();
|
||||
m.setId(messageId);
|
||||
responseFuture->markCompleted(m);
|
||||
|
@ -83,28 +83,23 @@ void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(worker_id_t ownerId) {
|
||||
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const TypePtr& type) {
|
||||
TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
|
||||
// Explicitly creating rrefId before forkId to make sure the order is
|
||||
// deterministic, as the argument evaluation order is system and compiler
|
||||
// dependent.
|
||||
const auto rrefId = genGloballyUniqueId();
|
||||
const auto forkId = genGloballyUniqueId();
|
||||
return createUserRRef<T>(ownerId, rrefId, forkId);
|
||||
return createUserRRef(ownerId, rrefId, forkId, type);
|
||||
}
|
||||
|
||||
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
|
||||
worker_id_t ownerId);
|
||||
|
||||
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
|
||||
py::object>(worker_id_t ownerId);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
|
||||
std::shared_ptr<UserRRef> RRefContext::createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId) {
|
||||
const ForkId& forkId,
|
||||
const TypePtr& type) {
|
||||
TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
|
||||
// RRefContext does not track user RRefs, it will be destructed when there
|
||||
// is no shared_ptrs pointing to it.
|
||||
@ -119,20 +114,9 @@ std::shared_ptr<UserRRef<T>> RRefContext::createUserRRef(
|
||||
// The reason for not adding the pending user here is to put addPendingUser()
|
||||
// close to where the RPC occurs, and it is more clear to pair it with
|
||||
// deletePendingUser() in the response callback at the call site.
|
||||
return std::shared_ptr<UserRRef<T>>(new UserRRef<T>(ownerId, rrefId, forkId));
|
||||
return std::shared_ptr<UserRRef>(new UserRRef(ownerId, rrefId, forkId, type));
|
||||
}
|
||||
|
||||
template std::shared_ptr<UserRRef<IValue>> RRefContext::createUserRRef<IValue>(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId);
|
||||
|
||||
template std::shared_ptr<UserRRef<py::object>> RRefContext::createUserRRef<
|
||||
py::object>(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId);
|
||||
|
||||
void RRefContext::delUser(
|
||||
const worker_id_t owner,
|
||||
const RRefId& rrefId,
|
||||
@ -150,27 +134,24 @@ void RRefContext::delUser(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(const RRefForkData& rfd) {
|
||||
std::shared_ptr<RRef> RRefContext::getOrCreateRRef(
|
||||
const RRefForkData& rfd,
|
||||
const TypePtr& type) {
|
||||
auto& ownerId = rfd.ownerId_;
|
||||
auto& rrefId = rfd.rrefId_;
|
||||
auto& forkId = rfd.forkId_;
|
||||
if (ownerId == getWorkerId()) {
|
||||
return getOwnerRRef<T>(rrefId);
|
||||
auto ownerRRef = getOwnerRRef(rrefId);
|
||||
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
|
||||
return ownerRRef;
|
||||
} else {
|
||||
return createUserRRef<T>(ownerId, rrefId, forkId);
|
||||
return createUserRRef(ownerId, rrefId, forkId, type);
|
||||
}
|
||||
}
|
||||
|
||||
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<IValue>(
|
||||
const RRefForkData& rfd);
|
||||
|
||||
template std::shared_ptr<RRef> RRefContext::getOrCreateRRef<py::object>(
|
||||
const RRefForkData& rfd);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> RRefContext::getOrCreateOwnerRRef(
|
||||
const RRefId& rrefId) {
|
||||
std::shared_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const TypePtr& type) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto iter = owners_.find(rrefId);
|
||||
if (iter == owners_.end()) {
|
||||
@ -179,58 +160,40 @@ std::shared_ptr<OwnerRRef<T>> RRefContext::getOrCreateOwnerRRef(
|
||||
// NB: cannot use make_shared here as the constructor of OwnerRRef is
|
||||
// private.
|
||||
auto rref =
|
||||
std::shared_ptr<OwnerRRef<T>>(new OwnerRRef<T>(getWorkerId(), rrefId));
|
||||
std::shared_ptr<OwnerRRef>(new OwnerRRef(getWorkerId(), rrefId, type));
|
||||
owners_[rref->rrefId()] = rref;
|
||||
ownerCV_.notify_all();
|
||||
return rref;
|
||||
} else {
|
||||
// Scenario (2) retrieving an existing RRef
|
||||
return std::static_pointer_cast<OwnerRRef<T>>(iter->second);
|
||||
auto ownerRRef = std::static_pointer_cast<OwnerRRef>(iter->second);
|
||||
TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
|
||||
return ownerRRef;
|
||||
}
|
||||
}
|
||||
|
||||
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::getOrCreateOwnerRRef<
|
||||
IValue>(const RRefId& rrefId);
|
||||
|
||||
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::
|
||||
getOrCreateOwnerRRef<py::object>(const RRefId& rrefId);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> RRefContext::createOwnerRRef() {
|
||||
std::shared_ptr<OwnerRRef> RRefContext::createOwnerRRef(const TypePtr& type) {
|
||||
// Don't add this OnwerRRef to the owners_ map yet, otherwise
|
||||
// it will never be removed from there. Instead, only add it to the
|
||||
// map in prepareChildFork, in case this local RRef is being passed
|
||||
// to another worker.
|
||||
return std::shared_ptr<OwnerRRef<T>>(
|
||||
new OwnerRRef<T>(getWorkerId(), genGloballyUniqueId()));
|
||||
return std::shared_ptr<OwnerRRef>(
|
||||
new OwnerRRef(getWorkerId(), genGloballyUniqueId(), type));
|
||||
}
|
||||
|
||||
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::createOwnerRRef<
|
||||
IValue>();
|
||||
|
||||
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::createOwnerRRef<
|
||||
py::object>();
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> RRefContext::getOwnerRRef(const RRefId& rrefId) {
|
||||
std::shared_ptr<OwnerRRef> RRefContext::getOwnerRRef(const RRefId& rrefId) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
const auto iter = owners_.find(rrefId);
|
||||
if (iter == owners_.end()) {
|
||||
// Scenario (1) RRef is used before it is created
|
||||
ownerCV_.wait(lock, [&] { return owners_.find(rrefId) != owners_.end(); });
|
||||
return std::static_pointer_cast<OwnerRRef<T>>(owners_[rrefId]);
|
||||
return std::static_pointer_cast<OwnerRRef>(owners_[rrefId]);
|
||||
} else {
|
||||
// Scenario (2) retrieving an existing RRef
|
||||
return std::static_pointer_cast<OwnerRRef<T>>(iter->second);
|
||||
return std::static_pointer_cast<OwnerRRef>(iter->second);
|
||||
}
|
||||
}
|
||||
|
||||
template std::shared_ptr<OwnerRRef<IValue>> RRefContext::getOwnerRRef<IValue>(
|
||||
const RRefId& rrefId);
|
||||
|
||||
template std::shared_ptr<OwnerRRef<py::object>> RRefContext::getOwnerRRef<
|
||||
py::object>(const RRefId& rrefId);
|
||||
|
||||
RRefForkData RRefContext::prepareChildFork(const std::shared_ptr<RRef>& rref) {
|
||||
auto rfd = rref->fork();
|
||||
if (rref->isOwner()) {
|
||||
@ -367,8 +330,7 @@ void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref) {
|
||||
void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef>& rref) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
const auto& rrefId = rref->rrefId();
|
||||
owners_[rrefId] = rref;
|
||||
@ -380,12 +342,6 @@ void RRefContext::addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref) {
|
||||
rrefForks.insert(rrefId);
|
||||
}
|
||||
|
||||
template void RRefContext::addSelfAsFork<IValue>(
|
||||
std::shared_ptr<OwnerRRef<IValue>>& rref);
|
||||
|
||||
template void RRefContext::addSelfAsFork<py::object>(
|
||||
std::shared_ptr<OwnerRRef<py::object>>& rref);
|
||||
|
||||
void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto& rrefForks = forks_[rrefId];
|
||||
|
@ -47,26 +47,27 @@ class RRefContext {
|
||||
}
|
||||
|
||||
// create a ``UserRRef`` owned by the worker ``ownerId``
|
||||
template <typename T>
|
||||
std::shared_ptr<UserRRef<T>> createUserRRef(worker_id_t ownerId);
|
||||
std::shared_ptr<UserRRef> createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const TypePtr& type);
|
||||
|
||||
// Convert an RRefForkData into an RRef. This RRef could be user or owner.
|
||||
// This RRef could have already existed before, or could be created in this
|
||||
// method.
|
||||
template <typename T>
|
||||
std::shared_ptr<RRef> getOrCreateRRef(const RRefForkData& rfd);
|
||||
// method, we pass type here to validate or help the rref creation.
|
||||
std::shared_ptr<RRef> getOrCreateRRef(
|
||||
const RRefForkData& rfd,
|
||||
const TypePtr& type);
|
||||
|
||||
// Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
|
||||
// one.
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> getOrCreateOwnerRRef(const RRefId& rrefId);
|
||||
std::shared_ptr<OwnerRRef> getOrCreateOwnerRRef(
|
||||
const RRefId& rrefId,
|
||||
const TypePtr& type);
|
||||
|
||||
// Create an empty owner rref of type T.
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> createOwnerRRef();
|
||||
// Create an empty owner rref of type.
|
||||
std::shared_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<OwnerRRef<T>> getOwnerRRef(const RRefId& rrefId);
|
||||
std::shared_ptr<OwnerRRef> getOwnerRRef(const RRefId& rrefId);
|
||||
|
||||
// Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
|
||||
// making a remote call to self, which as for now, still goes through serde
|
||||
@ -78,8 +79,7 @@ class RRefContext {
|
||||
// and this could happen before the self remote call finishes. To prevent
|
||||
// that, this API adds the RRefId as a ForkId, which will then delete the
|
||||
// ForkId when the self remote is done.
|
||||
template <typename T>
|
||||
void addSelfAsFork(std::shared_ptr<OwnerRRef<T>>& rref);
|
||||
void addSelfAsFork(std::shared_ptr<OwnerRRef>& rref);
|
||||
|
||||
// Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the
|
||||
// ``OwnerRRef`` in a map to keep it alive.
|
||||
@ -124,11 +124,11 @@ class RRefContext {
|
||||
private:
|
||||
RRefContext(std::shared_ptr<RpcAgent>);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<UserRRef<T>> createUserRRef(
|
||||
std::shared_ptr<UserRRef> createUserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId);
|
||||
const ForkId& forkId,
|
||||
const TypePtr& type);
|
||||
|
||||
void finishForkRequest(const ForkId& forkId, worker_id_t parent);
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <torch/csrc/distributed/rpc/rref_context.h>
|
||||
#include <torch/csrc/distributed/rpc/rref_proto.h>
|
||||
#include <torch/csrc/distributed/rpc/utils.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
@ -19,9 +20,10 @@ constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
|
||||
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
|
||||
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
|
||||
constexpr int PARENT_IDX = 5; // index of parent in the tuple
|
||||
constexpr int TYPE_IDX = 6; // index of parent in the tuple
|
||||
|
||||
// NB: if more fields are added, make sure this field is also bumped
|
||||
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple
|
||||
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
|
||||
} // namespace
|
||||
|
||||
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
|
||||
@ -32,8 +34,13 @@ RRefForkData::RRefForkData(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
worker_id_t parent)
|
||||
: ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {}
|
||||
worker_id_t parent,
|
||||
std::string type_str)
|
||||
: ownerId_(ownerId),
|
||||
rrefId_(rrefId),
|
||||
forkId_(forkId),
|
||||
parent_(parent),
|
||||
type_str_(std::move(type_str)) {}
|
||||
|
||||
py::tuple RRefForkData::toPyTuple() const {
|
||||
return py::make_tuple(
|
||||
@ -42,7 +49,8 @@ py::tuple RRefForkData::toPyTuple() const {
|
||||
rrefId_.localId_,
|
||||
forkId_.createdOn_,
|
||||
forkId_.localId_,
|
||||
parent_);
|
||||
parent_,
|
||||
type_str_);
|
||||
}
|
||||
|
||||
RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
|
||||
@ -57,29 +65,39 @@ RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) {
|
||||
const RRefId& forkId = RRefId(
|
||||
t[FORKID_ON_IDX].cast<worker_id_t>(),
|
||||
t[FORKID_ID_IDX].cast<local_id_t>());
|
||||
|
||||
worker_id_t parent = t[PARENT_IDX].cast<worker_id_t>();
|
||||
return RRefForkData(ownerId, rrefId, forkId, parent);
|
||||
const std::string& typeStr = t[TYPE_IDX].cast<std::string>();
|
||||
|
||||
return RRefForkData(ownerId, rrefId, forkId, parent, typeStr);
|
||||
}
|
||||
|
||||
////////////////////////////// RRef /////////////////////////////////////
|
||||
|
||||
RRef::RRef(worker_id_t ownerId, const RRefId& rrefId)
|
||||
: RRefInterface(), ownerId_(ownerId), rrefId_(rrefId) {}
|
||||
RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
|
||||
: RRefInterface(),
|
||||
ownerId_(ownerId),
|
||||
rrefId_(rrefId),
|
||||
type_(std::move(type)) {}
|
||||
|
||||
RRefForkData RRef::fork() const {
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
return RRefForkData(
|
||||
ownerId_, rrefId_, ctx.genGloballyUniqueId(), ctx.getWorkerId());
|
||||
ownerId_,
|
||||
rrefId_,
|
||||
ctx.genGloballyUniqueId(),
|
||||
ctx.getWorkerId(),
|
||||
type_->str());
|
||||
}
|
||||
|
||||
////////////////////////// UserRRef /////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
UserRRef<T>::UserRRef(
|
||||
UserRRef::UserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId)
|
||||
: RRef(ownerId, rrefId), forkId_(forkId) {
|
||||
const ForkId& forkId,
|
||||
TypePtr type)
|
||||
: RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
|
||||
// Do nothing,
|
||||
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
|
||||
// a RREF_FORK_REQUEST message to the owner.
|
||||
@ -87,8 +105,7 @@ UserRRef<T>::UserRRef(
|
||||
// properly notify the owner.
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
UserRRef<T>::~UserRRef() {
|
||||
UserRRef::~UserRRef() {
|
||||
try {
|
||||
RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
|
||||
} catch (const std::exception& ex) {
|
||||
@ -102,80 +119,65 @@ UserRRef<T>::~UserRRef() {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const ForkId& UserRRef<T>::forkId() const {
|
||||
const ForkId& UserRRef::forkId() const {
|
||||
return forkId_;
|
||||
}
|
||||
|
||||
template <>
|
||||
IValue UserRRef<IValue>::toHere() {
|
||||
IValue UserRRef::toHere() {
|
||||
auto agent = RpcAgent::getDefaultRpcAgent();
|
||||
|
||||
// ScriptRRefFetchCall message always carries autograd context id even if
|
||||
// the message itself does not contain any tensor, because the response would
|
||||
// potentially contain tensors.
|
||||
Message msgToSend;
|
||||
|
||||
if (isPyObj()) {
|
||||
msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
|
||||
} else {
|
||||
msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
|
||||
}
|
||||
|
||||
auto futureResponse = autograd::sendMessageWithAutograd(
|
||||
*agent,
|
||||
agent->getWorkerInfo(ownerId_),
|
||||
ScriptRRefFetchCall(ownerId_, rrefId()).toMessage(),
|
||||
std::move(msgToSend),
|
||||
true /* forceGradRecording */);
|
||||
|
||||
const Message& message = futureResponse->wait();
|
||||
MessageType msgType = message.type();
|
||||
auto response = deserializeResponse(message, msgType);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
msgType == MessageType::SCRIPT_RREF_FETCH_RET,
|
||||
"Message type should be SCRIPT_RREF_FETCH_RET.");
|
||||
msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
|
||||
msgType == MessageType::PYTHON_RREF_FETCH_RET,
|
||||
"Message type should either be SCRIPT_RREF_FETCH_RET "
|
||||
"or PYTHON_RREF_FETCH_RET");
|
||||
RpcCommandBase& rpc = *response;
|
||||
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
|
||||
return rfr.values().front();
|
||||
if (isPyObj()) {
|
||||
auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
|
||||
return jit::toIValue(
|
||||
PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(rfr.values())),
|
||||
PyObjectType::get());
|
||||
} else {
|
||||
auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
|
||||
return rfr.values().front();
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
py::object UserRRef<py::object>::toHere() {
|
||||
auto agent = RpcAgent::getDefaultRpcAgent();
|
||||
|
||||
// PythonRRefFetchCall message always carries autograd context id even if
|
||||
// the message itself does not contain any tensor, because the response would
|
||||
// potentially contain tensors.
|
||||
auto futureResponse = autograd::sendMessageWithAutograd(
|
||||
*agent,
|
||||
agent->getWorkerInfo(ownerId_),
|
||||
PythonRRefFetchCall(ownerId_, rrefId()).toMessage(),
|
||||
true /* forceGradRecording */);
|
||||
|
||||
const Message& message = futureResponse->wait();
|
||||
MessageType msgType = message.type();
|
||||
auto response = deserializeResponse(message, msgType);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
msgType == MessageType::PYTHON_RREF_FETCH_RET,
|
||||
"Message type should be PYTHON_RREF_FETCH_RET.");
|
||||
RpcCommandBase& rpc = *response;
|
||||
auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
|
||||
return PythonRpcHandler::getInstance().deserialize(
|
||||
SerializedPyObj::fromIValues(rfr.values()));
|
||||
}
|
||||
|
||||
template class UserRRef<IValue>;
|
||||
template class UserRRef<py::object>;
|
||||
|
||||
////////////////////////// OwnerRRef /////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
const T& OwnerRRef<T>::getValue() const {
|
||||
const IValue& OwnerRRef::getValue() const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
valueCV_.wait(lock, [this] { return value_.has_value(); });
|
||||
return value_.value();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool OwnerRRef<T>::hasValue() const {
|
||||
bool OwnerRRef::hasValue() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return value_.has_value();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
|
||||
std::shared_ptr<FutureMessage> OwnerRRef::getFuture() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
if (future_.get()) {
|
||||
return future_;
|
||||
@ -189,8 +191,7 @@ std::shared_ptr<FutureMessage> OwnerRRef<T>::getFuture() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void OwnerRRef<T>::setValue(T&& value) {
|
||||
void OwnerRRef::setValue(IValue&& value) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
value_ = std::move(value);
|
||||
std::shared_ptr<FutureMessage> future;
|
||||
@ -202,9 +203,6 @@ void OwnerRRef<T>::setValue(T&& value) {
|
||||
}
|
||||
}
|
||||
|
||||
template class OwnerRRef<IValue>;
|
||||
template class OwnerRRef<py::object>;
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
||||
|
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||
@ -15,7 +16,6 @@ namespace rpc {
|
||||
|
||||
class RRef;
|
||||
class RRefContext;
|
||||
template <typename T>
|
||||
class UserRRef;
|
||||
|
||||
// Represents fork of an RRef to be sent over the wire.
|
||||
@ -27,24 +27,21 @@ struct RRefForkData {
|
||||
const RRefId rrefId_;
|
||||
const ForkId forkId_;
|
||||
const worker_id_t parent_;
|
||||
const std::string type_str_;
|
||||
|
||||
private:
|
||||
friend class RRef;
|
||||
friend class RRefContext;
|
||||
template <typename T>
|
||||
friend class UserRRef;
|
||||
|
||||
RRefForkData(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId_,
|
||||
const ForkId& forkId_,
|
||||
worker_id_t parent);
|
||||
worker_id_t parent,
|
||||
std::string type_str);
|
||||
};
|
||||
|
||||
static_assert(
|
||||
C10_IS_TRIVIALLY_COPYABLE(RRefForkData),
|
||||
"RRefForkData must be trivially copyable");
|
||||
|
||||
// Note [RRef Protocol]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
//
|
||||
@ -207,25 +204,32 @@ class RRef : public RRefInterface {
|
||||
return rrefId_;
|
||||
}
|
||||
|
||||
// returns true if this RRef holds an py::object, false if IValue
|
||||
virtual bool isPyObj() = 0;
|
||||
inline bool isPyObj() {
|
||||
return type_ == PyObjectType::get();
|
||||
}
|
||||
inline const TypePtr type() {
|
||||
return type_;
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class RRefContext;
|
||||
|
||||
RRef(worker_id_t ownerId, const RRefId& rrefId);
|
||||
RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
|
||||
|
||||
RRefForkData fork() const;
|
||||
|
||||
const worker_id_t ownerId_;
|
||||
const RRefId rrefId_;
|
||||
|
||||
// type field to denote the type of the element that the RRef is holding
|
||||
// it could be any TypePtr that JIT support, including PyObjectType
|
||||
const TypePtr type_;
|
||||
};
|
||||
|
||||
// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
|
||||
// also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
|
||||
// never owns the real value, the only way to get the value of the ``RRef`` is
|
||||
// to call ``to_here()`` and get a copy..
|
||||
template <typename T>
|
||||
class UserRRef final : public RRef {
|
||||
public:
|
||||
UserRRef(const UserRRef& other) = delete;
|
||||
@ -237,16 +241,12 @@ class UserRRef final : public RRef {
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool isPyObj() override {
|
||||
return std::is_same<T, py::object>::value;
|
||||
}
|
||||
|
||||
// Returns the globally unique ForkId of this RRef
|
||||
const ForkId& forkId() const;
|
||||
|
||||
// Get of copy of the value from the ``OwnerRRef``. If the value is not ready
|
||||
// yet, this call will block.
|
||||
T toHere();
|
||||
IValue toHere();
|
||||
|
||||
// Upon destruction, this ``UserRRef`` will tell the owner to deref.
|
||||
~UserRRef() override;
|
||||
@ -254,14 +254,17 @@ class UserRRef final : public RRef {
|
||||
private:
|
||||
friend class RRefContext;
|
||||
|
||||
UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId);
|
||||
UserRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
TypePtr type);
|
||||
|
||||
const ForkId forkId_;
|
||||
};
|
||||
|
||||
// Keep the template only on the derived class because ``RRefContext`` needs to
|
||||
// erase the type on ``RRef`` and keep them in one map.
|
||||
template <typename T>
|
||||
class OwnerRRef final : public RRef {
|
||||
public:
|
||||
OwnerRRef(const OwnerRRef& other) = delete;
|
||||
@ -273,18 +276,14 @@ class OwnerRRef final : public RRef {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool isPyObj() override {
|
||||
return std::is_same<T, py::object>::value;
|
||||
}
|
||||
|
||||
// Get a constant reference of the real value. This method will block if the
|
||||
// value is not ready. This method does not need GIL as it does not create
|
||||
// any new py::object.
|
||||
const T& getValue() const;
|
||||
const IValue& getValue() const;
|
||||
|
||||
// Set the value of this ``OwnerRRef``. This method does not need GIL as it
|
||||
// does not create any new py::object.
|
||||
void setValue(T&& value);
|
||||
void setValue(IValue&& value);
|
||||
|
||||
// Has a value been set?
|
||||
bool hasValue() const;
|
||||
@ -294,15 +293,19 @@ class OwnerRRef final : public RRef {
|
||||
private:
|
||||
friend class RRefContext;
|
||||
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId)
|
||||
: OwnerRRef(ownerId, rrefId, {}) {}
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
|
||||
: OwnerRRef(ownerId, rrefId, type, {}) {}
|
||||
|
||||
OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional<T> value)
|
||||
: RRef(ownerId, rrefId) {
|
||||
OwnerRRef(
|
||||
worker_id_t ownerId,
|
||||
const RRefId& rrefId,
|
||||
TypePtr type,
|
||||
c10::optional<IValue> value)
|
||||
: RRef(ownerId, rrefId, std::move(type)) {
|
||||
value_ = std::move(value);
|
||||
}
|
||||
|
||||
c10::optional<T> value_;
|
||||
c10::optional<IValue> value_;
|
||||
mutable std::mutex mutex_;
|
||||
mutable std::condition_variable valueCV_;
|
||||
std::shared_ptr<FutureMessage> future_;
|
||||
|
Reference in New Issue
Block a user