Don't run user function until all UserRRefs in the args are confirmed (#34497)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497

Use a thread_local table to intercept UserRRefs created during user
function args deserialization, and then wait for confirmations of
those UserRRefs before launching the given user function.

Differential Revision: D20347464

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
This commit is contained in:
Shen Li
2020-03-16 18:23:45 -07:00
committed by Facebook GitHub Bot
parent d876fef743
commit 422e348619
26 changed files with 590 additions and 122 deletions

View File

@ -27,7 +27,11 @@ class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
// Returns true if this is the ``OwnerRRef``
virtual bool isOwner() const = 0;
// Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been
// confirmed by its owner.
virtual bool confirmedByOwner() const = 0;
virtual const TypePtr type() const = 0;
};
}
}

View File

@ -111,6 +111,7 @@ white_list = [
('aten::_linear_prepack', datetime.date(2020, 4, 1)),
('aten::_conv2d_packed', datetime.date(2020, 4, 1)),
('aten::_conv2d_prepack', datetime.date(2020, 4, 1)),
('aten::confirmed_by_owner', datetime.date(2020, 3, 17)),
]

View File

@ -347,6 +347,8 @@ def add_torch_libs():
"torch/csrc/distributed/rpc/python_functions.cpp",
"torch/csrc/distributed/rpc/python_rpc_handler.cpp",
"torch/csrc/distributed/rpc/request_callback_impl.cpp",
"torch/csrc/distributed/rpc/unpickled_python_call.cpp",
"torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp",
"torch/csrc/jit/python/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",

View File

@ -258,6 +258,8 @@ if (USE_DISTRIBUTED)
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/unpickled_python_call.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/unpickled_python_remote_call.cpp
${TORCH_SRC_DIR}/csrc/jit/runtime/register_distributed_ops.cpp
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)

View File

@ -166,6 +166,11 @@ RpcCommandBase& RpcWithAutograd::wrappedRpc() {
return *wrappedRpc_;
}
// Replaces the wrapped RPC command with `wrappedRpc`, taking ownership of it.
// Whatever command was held before is released by the move assignment.
void RpcWithAutograd::setWrappedRpc(
std::unique_ptr<RpcCommandBase> wrappedRpc) {
wrappedRpc_ = std::move(wrappedRpc);
}
std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
return std::move(wrappedRpc_);

View File

@ -42,6 +42,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
RpcCommandBase& wrappedRpc();
void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);
std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
// Message type of the wrapped RPC.

View File

@ -168,6 +168,14 @@ PyObject* rpc_init(PyObject* /* unused */) {
Returns whether or not the current node is the owner of this
``RRef``.
)")
.def(
"confirmed_by_owner",
&PyRRef::confirmedByOwner,
R"(
Returns whether this ``RRef`` has been confirmed by the owner.
``OwnerRRef`` always returns true, while ``UserRRef`` only
returns true when the owner knowns about this ``UserRRef``.
)")
.def(
// not releasing GIL here to avoid context switch on getters
"owner",

View File

@ -112,21 +112,19 @@ void Message::setId(int64_t id) {
id_ = id;
}
Message createExceptionResponse(
const Message& request,
const std::exception& e) {
return createExceptionResponse(request, e.what());
Message createExceptionResponse(int64_t requestId, const std::exception& e) {
return createExceptionResponse(requestId, e.what());
}
Message createExceptionResponse(
const Message& request,
int64_t requestId,
const std::string& exceptionStr) {
std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
return Message(
std::move(payload),
std::vector<torch::Tensor>(),
MessageType::EXCEPTION,
request.id());
requestId);
}
} // namespace rpc

View File

@ -118,13 +118,12 @@ class TORCH_API Message final {
// Create a response Message with an exception for the provided request message.
// The exception string representation will be used as the message's payload.
TORCH_API Message
createExceptionResponse(const Message& request, const std::exception& e);
createExceptionResponse(int64_t requestId, const std::exception& e);
// Create a response Message with an exception type for the provided request
// message. The passed in string will be used as the created message's payload
TORCH_API Message createExceptionResponse(
const Message& request,
const std::string& exceptionStr);
TORCH_API Message
createExceptionResponse(int64_t requestId, const std::string& exceptionStr);
typedef torch::utils::Future<Message> FutureMessage;

View File

@ -416,7 +416,7 @@ void ProcessGroupAgent::enqueueSend(SendWork work) {
"Encountered exception in ProcessGroupAgent::enqueueSend: ",
e.what());
auto exceptionMsg =
rpc::createExceptionResponse(work.message_, errorStr);
rpc::createExceptionResponse(work.message_.id(), errorStr);
if (work.message_.isRequest()) {
markFutureWithError(exceptionMsg);
} else if (work.message_.isResponse()) {
@ -455,7 +455,7 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
send(
work.from_,
createExceptionResponse(
message, futureResponse->error()->what()));
message.id(), futureResponse->error()->what()));
}
} else {
++serverActiveAsyncCalls_;
@ -685,8 +685,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
"RPC ran for more than ",
timedOutFuture.timeout_.count(),
" milliseconds and timed out.");
const auto exceptionMsg = createExceptionResponse(
Message({}, {}, MessageType::EXCEPTION), errorStr);
const auto exceptionMsg = createExceptionResponse(-1, errorStr);
if (!timedOutFuture.future_->hasError()) {
--clientActiveCalls_;
timedOutFuture.future_->setError(std::string(

View File

@ -113,6 +113,10 @@ bool PyRRef::isOwner() const {
return rref_->isOwner();
}
// Reports whether the wrapped RRef has been confirmed by its owner; simply
// forwards to the underlying RRef.
bool PyRRef::confirmedByOwner() const {
return rref_->confirmedByOwner();
}
WorkerInfo PyRRef::owner() const {
return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
}

View File

@ -16,6 +16,7 @@ class PyRRef {
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
bool isOwner() const;
bool confirmedByOwner() const;
WorkerInfo owner() const;
py::object toHere();
py::object localValue();

View File

@ -9,45 +9,14 @@ namespace rpc {
using namespace torch::distributed::autograd;
namespace {
// When request message has autograd info, processMessage() will set up valid
// current context id properly. This struct is used to clean up current context
// id after processMessage() is done.
struct ClearAutogradContextGuard {
ClearAutogradContextGuard() = default;
~ClearAutogradContextGuard() {
clear();
}
void clear() {
auto& autogradContainer = DistAutogradContainer::getInstance();
autogradContainer.clearCurrentContext();
}
};
} // anonymous namespace
std::shared_ptr<FutureMessage> RequestCallback::operator()(
Message& request) const {
// For a recv thread, current context id should be invalid outside
// processMessage().
ClearAutogradContextGuard guard;
try {
return processMessage(request);
} catch (std::exception& e) {
LOG(ERROR) << "Received error while processing request type "
<< request.type() << ": " << e.what();
// Adding node information to the error here since all processed RPC
// requests should be going through this function.
std::string errorMsg = c10::str(
"Error on Node ",
DistAutogradContainer::getInstance().getWorkerId(),
": ",
e.what());
return std::make_shared<FutureMessage>(
createExceptionResponse(request, errorMsg));
}
// NB: cannot clear autograd context id here because the processMessage method
// might pause waiting for all RRefs in the arguments to be confirmed by their
// owners and resume processing in a different thread. Hence, the
// thread_local context id needs to be set and cleared in the thread that
// indeed carries out the processing logic.
return processMessage(request);
}
} // namespace rpc

View File

@ -20,6 +20,8 @@
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/python/pybind_utils.h>
@ -29,13 +31,86 @@ namespace rpc {
using namespace torch::distributed::autograd;
std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
namespace {
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
RpcCommandBase& rpc,
MessageType messageType,
const MessageType& messageType) {
switch (messageType) {
case MessageType::PYTHON_CALL: {
auto& pc = static_cast<PythonCall&>(rpc);
return std::make_unique<UnpickledPythonCall>(pc.serializedPyObj());
}
case MessageType::PYTHON_REMOTE_CALL: {
auto& prc = static_cast<PythonRemoteCall&>(rpc);
return std::make_unique<UnpickledPythonRemoteCall>(
prc.serializedPyObj(), prc.retRRefId(), prc.retForkId());
}
case MessageType::FORWARD_AUTOGRAD_REQ: {
// Deserialize the wrapped RPC if it contains Python UDF
auto& rwa = static_cast<RpcWithAutograd&>(rpc);
auto& wrappedRpc = rwa.wrappedRpc();
auto pythonRpc = deserializePythonRpcCommandReference(
wrappedRpc, rwa.wrappedMessageType());
if (pythonRpc) {
rwa.setWrappedRpc(std::move(pythonRpc));
}
return nullptr;
}
default: {
return nullptr;
}
}
}
// Returns the command that should be handed to the RPC processing logic.
// If deserializePythonRpcCommandReference() produces a replacement command
// for `rpc`, that replacement is returned; otherwise ownership of the
// original `rpc` is handed back to the caller unchanged.
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
    std::unique_ptr<RpcCommandBase> rpc,
    const MessageType& messageType) {
  auto unpickled = deserializePythonRpcCommandReference(*rpc, messageType);
  if (unpickled) {
    return unpickled;
  }
  return rpc;
}
// When request message has autograd info, processMessage() will set up valid
// current context id properly. This struct is used to clean up current context
// id after processMessage() is done.
// Scope guard that clears the current distributed autograd context of the
// DistAutogradContainer singleton when it is destroyed. clear() can also be
// called directly to do the same thing before the guard goes out of scope.
struct ClearAutogradContextGuard {
  ClearAutogradContextGuard() = default;

  ~ClearAutogradContextGuard() {
    clear();
  }

  void clear() {
    DistAutogradContainer::getInstance().clearCurrentContext();
  }
};
} // anonymous namespace
Message RequestCallbackImpl::handleError(
const std::exception& e,
const MessageType messageType,
int64_t messageId) const {
auto wrap = [messageId](Message m) {
LOG(ERROR) << "Received error while processing request type " << messageType
<< ": " << e.what();
// Adding node information to the error here since all processed RPC
// requests should be going through this function.
std::string errorMsg = c10::str(
"Error on Node ",
DistAutogradContainer::getInstance().getWorkerId(),
": ",
e.what());
return createExceptionResponse(messageId, errorMsg);
}
void RequestCallbackImpl::processRpc(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const std::shared_ptr<FutureMessage>& responseFuture) const {
auto markComplete = [messageId, &responseFuture](Message m) {
m.setId(messageId);
return std::make_shared<FutureMessage>(std::move(m));
responseFuture->markCompleted(std::move(m));
};
// TODO: RpcCommandBase should have an abstract execute() method that we can
@ -65,22 +140,22 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
"TorchScript function should be a single IValue, got a vector of "
"size ",
stack.size());
return wrap(std::move(ScriptResp(std::move(stack.front()))).toMessage());
markComplete(std::move(ScriptResp(std::move(stack.front()))).toMessage());
return;
}
case MessageType::PYTHON_CALL: {
auto& pyCall = static_cast<PythonCall&>(rpc);
auto& upc = static_cast<UnpickledPythonCall&>(rpc);
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
std::shared_ptr<SerializedPyObj> serializedPyObj = nullptr;
{
pybind11::gil_scoped_acquire ag;
auto pythonUdf = pythonRpcHandler.deserialize(pyCall.serializedPyObj());
serializedPyObj =
std::make_shared<SerializedPyObj>(pythonRpcHandler.serialize(
pythonRpcHandler.runPythonUdf(std::move(pythonUdf))));
pythonRpcHandler.runPythonUdf(std::move(upc).movePythonUdf())));
}
return wrap(
markComplete(
std::move(PythonResp(std::move(*serializedPyObj))).toMessage());
return;
}
case MessageType::SCRIPT_REMOTE_CALL: {
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
@ -135,13 +210,14 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
// rrefId (OwnerRRef does not have a forkId anyway).
ctx.addForkOfOwner(rrefId, forkId);
}
return wrap(RemoteRet(rrefId, forkId).toMessage());
markComplete(RemoteRet(rrefId, forkId).toMessage());
return;
}
case MessageType::PYTHON_REMOTE_CALL: {
auto& prc = static_cast<PythonRemoteCall&>(rpc);
auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
auto rrefId = RRefId::fromIValue(prc.retRRefId());
auto forkId = ForkId::fromIValue(prc.retForkId());
const auto& rrefId = uprc.rrefId();
const auto& forkId = uprc.forkId();
auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get());
@ -150,9 +226,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
IValue py_ivalue;
{
pybind11::gil_scoped_acquire ag;
auto pythonUdf = pythonRpcHandler.deserialize(prc.serializedPyObj());
py_ivalue = jit::toIValue(
pythonRpcHandler.runPythonUdf(std::move(pythonUdf)),
pythonRpcHandler.runPythonUdf(std::move(uprc).movePythonUdf()),
PyObjectType::get());
}
@ -169,32 +244,33 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
// rrefId (OwnerRRef does not have a forkId anyway).
ctx.addForkOfOwner(rrefId, forkId);
}
return wrap(RemoteRet(rrefId, forkId).toMessage());
markComplete(RemoteRet(rrefId, forkId).toMessage());
return;
}
case MessageType::SCRIPT_RREF_FETCH_CALL: {
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
c10::intrusive_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
if (rref->hasValue()) { // optional fast-path
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
markComplete(ScriptRRefFetchRet({rref->getValue()}).toMessage());
return;
} else {
auto whenValueSet = rref->getFuture();
// Our response is satisfied when the rpcs come back.
whenValueSet->addCallback(
[responseFuture, messageId, rref](
const rpc::Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (!error) {
Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} else {
responseFuture->setError(error->what());
}
});
}
auto whenValueSet = rref->getFuture();
auto responseFuture = std::make_shared<FutureMessage>();
// Our response is satisfied when the rpcs come back.
whenValueSet->addCallback(
[responseFuture, messageId, rref](
const rpc::Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (!error) {
Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} else {
responseFuture->setError(error->what());
}
});
return responseFuture;
return;
}
case MessageType::PYTHON_RREF_FETCH_CALL: {
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
@ -209,12 +285,12 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
}
SerializedPyObj result =
PythonRpcHandler::getInstance().serialize(pyValue);
return wrap(
markComplete(
PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
return;
}
auto whenValueSet = rref->getFuture();
auto responseFuture = std::make_shared<FutureMessage>();
// Our response is satisfied when the rpcs come back.
whenValueSet->addCallback(
@ -238,7 +314,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
responseFuture->setError(error->what());
}
});
return responseFuture;
return;
}
case MessageType::RREF_USER_DELETE: {
auto& rud = static_cast<RRefUserDelete&>(rpc);
@ -248,19 +324,22 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
pybind11::gil_scoped_acquire ag;
deletedRRef.reset();
}
return wrap(std::move(RRefAck()).toMessage());
markComplete(std::move(RRefAck()).toMessage());
return;
}
case MessageType::RREF_CHILD_ACCEPT: {
auto& rca = static_cast<RRefChildAccept&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx.delPendingChild(rca.forkId());
return wrap(std::move(RRefAck()).toMessage());
markComplete(std::move(RRefAck()).toMessage());
return;
}
case MessageType::RREF_FORK_REQUEST: {
auto& rfr = static_cast<RRefForkRequest&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx.addForkOfOwner(rfr.rrefId(), rfr.forkId());
return wrap(RRefAck().toMessage());
markComplete(RRefAck().toMessage());
return;
}
case MessageType::FORWARD_AUTOGRAD_REQ: {
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
@ -283,13 +362,16 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
// Process the original RPC.
auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
// Make an overall future for the wrapped response.
auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
// Kick off processing for the nested future and get a Future<T> to the
// result.
auto wrappedRpcResponseFuture = processRpc(
rpcWithAutograd.wrappedRpc(), wrappedMessageType, messageId);
processRpc(
rpcWithAutograd.wrappedRpc(),
wrappedMessageType,
messageId,
wrappedRpcResponseFuture);
// Make an overall future for the wrapped response.
auto responseFuture = std::make_shared<rpc::FutureMessage>();
auto fromWorkerId = rpcWithAutograd.fromWorkerId();
// The original future needs to be marked as completed when the wrapped
// one completes, with the autograd context information wrapped.
@ -309,7 +391,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
responseFuture->markCompleted(std::move(msg));
}
});
return responseFuture;
return;
}
case MessageType::BACKWARD_AUTOGRAD_REQ: {
auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
@ -328,8 +410,6 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
// Attach the gradients to the send function.
sendFunction->setGrads(gradientsCall.getGrads());
auto responseFuture = std::make_shared<rpc::FutureMessage>();
// Now execute the autograd graph using the "distributed engine."
auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
autogradContext, sendFunction, gradientsCall.retainGraph());
@ -347,7 +427,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
responseFuture->setError(error->what());
}
});
return responseFuture;
return;
};
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
@ -358,7 +438,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
// notified to clean up their context.
DistAutogradContainer::getInstance().releaseContextIfPresent(
cleanupContextId);
return wrap(std::move(CleanupAutogradContextResp()).toMessage());
markComplete(std::move(CleanupAutogradContextResp()).toMessage());
return;
}
default: {
TORCH_INTERNAL_ASSERT(
@ -369,8 +450,42 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
std::shared_ptr<FutureMessage> RequestCallbackImpl::processMessage(
Message& request) const {
std::unique_ptr<RpcCommandBase> rpc = deserializeRequest(request);
return processRpc(*rpc, request.type(), request.id());
// We need two futures here because it could pause twice when processing a
// RPC message:
// 1) waiting for all RRefs in the arguments to become confirmed;
// 2) waiting for processRpc to finish.
auto retFuture = std::make_shared<FutureMessage>();
auto& rrefContext = RRefContext::getInstance();
try {
rrefContext.recordThreadLocalPendingRRefs();
std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
deserializeRequest(request), request.type());
auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();
rrefsReadyFuture->addCallback(
[this,
retFuture,
// std::function must be copyable, hence we have to cast the unique_ptr
// to a shared_ptr here.
rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
messageType = request.type(),
id = request.id()](
const bool& /*unused*/,
const c10::optional<utils::FutureError>& /*unused*/) {
try {
// For a recv thread, current context id should be invalid outside
// processMessage().
ClearAutogradContextGuard guard;
processRpc(*rpc, messageType, id, retFuture);
} catch (std::exception& e) {
retFuture->markCompleted(handleError(e, messageType, id));
}
});
} catch (std::exception& e) {
retFuture->markCompleted(handleError(e, request.type(), request.id()));
rrefContext.clearRecordedPendingRRefsOnError();
}
return retFuture;
}
} // namespace rpc

View File

@ -14,9 +14,15 @@ class TORCH_API RequestCallbackImpl : public RequestCallback {
Message& request) const override;
private:
std::shared_ptr<FutureMessage> processRpc(
void processRpc(
RpcCommandBase& rpc,
MessageType messageType,
const MessageType& messageType,
const int64_t messageId,
const std::shared_ptr<FutureMessage>& retFutureMessagge) const;
Message handleError(
const std::exception& e,
const MessageType messageType,
int64_t messageId) const;
};

View File

@ -7,6 +7,10 @@ namespace torch {
namespace distributed {
namespace rpc {
thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
RRefContext::userTable_;
thread_local bool RRefContext::recording = false;
namespace callback {
void confirmPendingUser(
const rpc::Message& message,
@ -438,38 +442,104 @@ void RRefContext::addPendingUser(
const c10::intrusive_ptr<RRef>& rref) {
TORCH_INTERNAL_ASSERT(
!rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");
auto state = std::make_shared<PendingUserState>(rref);
if (recording) {
// adding and waiting for pending users are guaranteed to be called from the
// same thread, but deleting pending users will be called from another
// thread. As the delPendingUser will not be able to access the same
// thread_local variable, we cannot address this problem by making
// pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
// the same PendingUserState shared_ptr.
userTable_.push_back(state);
}
std::lock_guard<std::mutex> lock(mutex_);
TORCH_INTERNAL_ASSERT(
pendingUsers_.find(forkId) == pendingUsers_.end(),
"Inconsistent states: attempt to add the same UserRRef twice.");
pendingUsers_[forkId] = rref;
pendingUsers_.emplace(
std::piecewise_construct,
std::forward_as_tuple(forkId),
std::forward_as_tuple(state));
}
void RRefContext::delPendingUser(const ForkId& forkId) {
c10::intrusive_ptr<RRef> deletedUser;
std::shared_ptr<PendingUserState> deletedState = nullptr;
{
std::lock_guard<std::mutex> lock(mutex_);
auto iter = pendingUsers_.find(forkId);
TORCH_INTERNAL_ASSERT(
iter != pendingUsers_.end(),
"Inconsistent states: attempt to delete a non-exist UserRRef.");
// There are two reasons for keeping the deleted PendingUserState alive
// until exiting the critical section.
// (1) Since this UserRRef is removed from the map, the refcount of this
// UserRRef could reach to 0. So the resource destructor
// (`release_resources()`) might be called, in which the lock is
// acquired again. Hence, it must be destructed with the lock released.
// To meet this constraint, we intentionally create a temporary pointer
// to increase the refcount of the deleted PendingUserState, extending
// its lifetime until the lock is released.
// (2) Since #34497, a user function only runs after all RRefs in the
// arguments are confirmed by their owners, which is done by adding the
// RPC processing logic as a callback to the UserRRef ready future. So,
// calling `confirm` on the PendingUserState could trigger pending user
// functions, which might in turn acquire the lock in RRefContext.
// Hence, we must release the lock to prevent deadlock.
// NB: Another option is to use reentrant lock. However, it is better for
// the developers to fully understand the locking behavior instead of
// hiding the subtle logic using a reentrant lock.
deletedState = iter->second; // Increase refcount
confirmedUsers_.emplace(
std::piecewise_construct,
std::forward_as_tuple(forkId),
std::forward_as_tuple(iter->second));
// Since this UserRRef is removed from the map,
// the refcount of this UserRRef could reach to 0,
// so the "destructor", `release_resources()`, might be called,
// in which the lock is acquired again.
// So it must be destructed with the lock released.
// Meet this constraint by creating a temporary pointer to increase the
// refcount, extending its lifetime until the lock is released.
deletedUser = iter->second; // Increase refcount.
std::forward_as_tuple(iter->second->rref_));
pendingUsers_.erase(iter); // Decrease refcount.
}
deletedState->confirm();
deleteAllUsersCV_.notify_all();
deletedUser.reset(); // Decrease refcount.
deletedState.reset(); // Decrease refcount.
}
// Starts a recording session on the calling thread by setting the
// thread_local `recording` flag. While the flag is set, addPendingUser() also
// appends every new PendingUserState to the thread_local userTable_.
// Asserts that userTable_ holds no leftovers from an earlier session.
void RRefContext::recordThreadLocalPendingRRefs() {
TORCH_INTERNAL_ASSERT(
userTable_.empty(),
"User RRef Table should be empty when start recording");
recording = true;
}
// Ends the current recording session and returns a future that is completed
// once every PendingUserState collected in the thread_local userTable_ has
// completed its own future_. With an empty table the returned future is
// completed immediately. The table is emptied and `recording` is reset before
// returning. The bool carried by the future has no meaning.
std::shared_ptr<torch::utils::Future<bool>> RRefContext::
    waitForThreadLocalPendingRRefs() {
  auto allConfirmed = std::make_shared<torch::utils::Future<bool>>();
  if (userTable_.empty()) {
    allConfirmed->markCompleted(true);
  } else {
    // Shared countdown: one callback per recorded state decrements it, and
    // the callback that observes the last outstanding entry completes the
    // aggregate future.
    auto outstanding =
        std::make_shared<std::atomic<uint64_t>>(userTable_.size());
    for (auto& pendingState : userTable_) {
      pendingState->future_.addCallback(
          [allConfirmed, outstanding](
              const bool& /* unused */,
              const c10::optional<utils::FutureError>& /* unused */) {
            if (outstanding->fetch_sub(1) == 1) {
              allConfirmed->markCompleted(true);
            }
          });
    }
    userTable_.clear();
  }
  recording = false;
  return allConfirmed;
}
// Abandons the current recording session on the calling thread: drops every
// entry collected in the thread_local userTable_ and resets the `recording`
// flag, so that a later recordThreadLocalPendingRRefs() starts from an empty
// table.
void RRefContext::clearRecordedPendingRRefsOnError() {
userTable_.clear();
recording = false;
}
void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {

View File

@ -5,6 +5,7 @@
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/utils/future.h>
#include <atomic>
@ -147,6 +148,27 @@ class TORCH_API RRefContext {
const c10::intrusive_ptr<RRef>& rref);
void delPendingUser(const ForkId& forkId);
// Start recording new pending UserRRefs. All pending UserRRefs introduced
// after this point will be put into the thread_local userTable_, which will
// then be consumed and cleared in waitForThreadLocalPendingRRefs().
void recordThreadLocalPendingRRefs();
// End recording new pending UserRRefs, and clear the thread_local userTable_.
// Returns a Future which will be marked as completed when all pending
// UserRRefs in the current userTable_ are confirmed by their owners. The bool
// value in the Future is unused.
// This method is useful to make sure RRefs in user function arguments are
// confirmed before launching user code.
// NB: Callers of this method do not need to keep the returned Future alive,
// because this Future is already captured in callbacks of the
// PendingUserState. If there are no pending UserRRefs, this method returns a
// completed future.
std::shared_ptr<torch::utils::Future<bool>> waitForThreadLocalPendingRRefs();
// Only call this function when there are errors during a recording session,
// and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked
// properly.
// TODO: make this a context guard
void clearRecordedPendingRRefsOnError();
void delUser(
const worker_id_t owner,
const RRefId& rrefId,
@ -156,6 +178,20 @@ class TORCH_API RRefContext {
std::unordered_map<std::string, std::string> getDebugInfo();
private:
// Bookkeeping for a UserRRef that has not yet been confirmed by its owner.
// It pairs the RRef with a future that is completed at confirmation time, so
// that work depending on the confirmation can be attached as callbacks.
struct PendingUserState {
  // `explicit` keeps an RRef pointer from silently converting into a
  // PendingUserState; instances are meant to be created deliberately (e.g.
  // through std::make_shared<PendingUserState>(rref)).
  explicit PendingUserState(c10::intrusive_ptr<RRef> rref)
      : rref_(std::move(rref)) {}

  // Marks the held RRef as confirmed by its owner, then completes future_,
  // which runs the callbacks registered on it. The cast treats rref_ as a
  // UserRRef without checking its dynamic type.
  inline void confirm() {
    c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm();
    future_.markCompleted(true);
  }

  // The pending RRef this state tracks.
  c10::intrusive_ptr<RRef> rref_;
  // Use Future.wait() and Future.markCompleted() to block and unblock user
  // functions. The bool value wrapped by the future_ is not used.
  torch::utils::Future<bool> future_;
};
RRefContext(std::shared_ptr<RpcAgent>);
c10::intrusive_ptr<UserRRef> createUserRRef(
@ -209,7 +245,7 @@ class TORCH_API RRefContext {
// It can be used or shared, but cannot be deleted, and hence kept alive
// in this map. A message of type RREF_USER_ACCEPT will move the
// corresponding RRef from pendingUsers_ map to confirmedUsers_ map.
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash>
pendingUsers_;
// UserRRefs are added into this map when it is confirmed by the owner.
// When destroying RRefContext this map helps to find local UserRRefs
@ -229,6 +265,34 @@ class TORCH_API RRefContext {
std::mutex destroyedMutex_;
bool destroyed_;
// Thread local states to keep UserRRefs deserialized from user function
// arguments.
static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_;
// A flag indicating whether subsequently created UserRRefs should be added to
// the thread_local userTable_. The flag is set to true before deserializing
// RPC arguments and then set to false before running the corresponding
// user code. See addPendingUser and delPendingUser for more details.
// NB: The reason for having this flag is that addPendingUser is called in
// two cases, and we only want to track the 2nd case.
// (1) RRef as the return value: when calling rpc.remote, the UserRRef on the
// caller side is added to the context using addPendingUser.
// (2) RRef as an argument: When running an RPC using RRefs as arguments, the
// RRef is forwarded to the callee as new UserRRefs (if the callee is not
// the owner). In this case, we block running the user function until all
// UserRRefs are confirmed by the owner.
// This contract guarantees that no UserRRefs can be used remotely without
// confirmation. Note that, however, the UserRRef created by rpc.remote can
// still be passed to local functions as arguments and used there. This is by
// design, because this feature is especially useful when, say a master node
// creates multiple UserRRefs in a loop and then shares them with other nodes.
// Blocking every iteration in the loop until RRefs are confirmed will slow
// this down. This nuance on UserRRef can be interpreted as we only make
// exceptions for UserRRef creators. And using the UserRRef on its creator
// without confirmation is OK, because the creator would either call to_here
// or forward the UserRRef, and both would then require confirmations from the
// owner.
static thread_local bool recording;
};
} // namespace rpc

View File

@ -70,7 +70,9 @@ UserRRef::UserRRef(
const RRefId& rrefId,
const ForkId& forkId,
TypePtr type)
: RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
: RRef(ownerId, rrefId, std::move(type)),
forkId_(forkId),
confirmedByOwner_(false) {
// Do nothing,
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
// a RREF_FORK_REQUEST message to the owner.

View File

@ -257,6 +257,10 @@ class TORCH_API UserRRef final : public RRef {
return false;
}
// Returns the current value of the confirmedByOwner_ flag, which starts out
// false and is switched to true by confirm().
inline bool confirmedByOwner() const override {
return confirmedByOwner_;
}
// Returns the globally unique ForkId of this RRef
const ForkId& forkId() const;
@ -280,6 +284,9 @@ class TORCH_API UserRRef final : public RRef {
friend class RRefContext;
RRefForkData fork() const override;
// Records that the owner has confirmed this UserRRef; afterwards
// confirmedByOwner() returns true.
inline void confirm() {
confirmedByOwner_ = true;
}
const ForkId forkId_;
@ -289,6 +296,8 @@ class TORCH_API UserRRef final : public RRef {
// proactive cleanup on RPC graceful shutdown.
std::mutex deletedOnOwnerMutex_;
bool deletedOnOwner_{false};
// Indicating whether this UserRRef has been confirmed by its owner.
std::atomic<bool> confirmedByOwner_;
};
// Keep the template only on the derived class because ``RRefContext`` needs to
@ -316,6 +325,12 @@ class TORCH_API OwnerRRef final : public RRef {
return true;
}
// Always returns true: an OwnerRRef lives on the owner itself, so there is
// nothing to wait for. A UserRRef, in contrast, is only confirmed once the
// owner knows about it.
inline bool confirmedByOwner() const override {
return true;
}
// Get a constant reference of the real value. This method will block if the
// value is not ready. This method does not need GIL as it does not create
// any new py::object.

View File

@ -0,0 +1,28 @@
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
namespace torch {
namespace distributed {
namespace rpc {
// Turns the serialized payload into a py::object and stores it in
// pythonUdf_.
UnpickledPythonCall::UnpickledPythonCall(
    const SerializedPyObj& serializedPyObj) {
  // Fetch the handler singleton first, exactly as before, and only then take
  // the GIL for the deserialization itself.
  auto& handler = PythonRpcHandler::getInstance();
  pybind11::gil_scoped_acquire gilGuard;
  pythonUdf_ = handler.deserialize(serializedPyObj);
}
// Intentionally unsupported: the assertion below fails unconditionally, so
// calling this is always reported as an internal error. There is no return
// statement because control is not expected to get past the assertion.
Message UnpickledPythonCall::toMessage() && {
TORCH_INTERNAL_ASSERT(
false, "UnpickledPythonCall does not support toMessage().");
}
// Hands the deserialized Python object over to the caller. The method is
// rvalue-qualified, so it can only be invoked on an expiring object, e.g.
// std::move(call).movePythonUdf(); pythonUdf_ is left in a moved-from state.
py::object UnpickledPythonCall::movePythonUdf() && {
// std::move is needed here: pythonUdf_ is a member, not a local.
return std::move(pythonUdf_);
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,33 @@
#pragma once
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
// This class converts the content in a PythonCall into py::object. This is a
// helper class to make sure that all arguments deserialization is done before
// entering RequestCallbackImpl::processRpc(...), so that the deserialization
// related logic can be carried out in one spot instead of scattered in multiple
// places for different message types.
// NB: The reason for not consolidating this class into PythonCall is that
// PythonCall is a libtorch type which should not depend on Python types.
class TORCH_API UnpickledPythonCall : public RpcCommandBase {
 public:
  // Deserializes serializedPyObj into the py::object held by this instance.
  explicit UnpickledPythonCall(const SerializedPyObj& serializedPyObj);

  // Dropping a reference to a Python object is only safe while holding the
  // GIL, and nothing guarantees that the thread destroying this command holds
  // it (for example when the UDF was never moved out because an error
  // occurred first). Release the reference explicitly under the GIL; the
  // implicit py::object destructor that runs afterwards then sees a null
  // handle and does nothing. If movePythonUdf() already emptied the member,
  // release() yields a null handle and dec_ref() is a no-op.
  ~UnpickledPythonCall() {
    pybind11::gil_scoped_acquire ag;
    pythonUdf_.release().dec_ref();
  }

  // toMessage() method is not implemented, as objects of this class should
  // never be directly converted into a Message object.
  Message toMessage() && override;

  // Moves the deserialized Python object out of this command. Only callable
  // on an rvalue, e.g. std::move(call).movePythonUdf().
  py::object movePythonUdf() &&;

 private:
  // Result of deserializing the SerializedPyObj passed to the constructor.
  py::object pythonUdf_;
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,28 @@
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
namespace torch {
namespace distributed {
namespace rpc {
// Deserializes the Python payload through the UnpickledPythonCall base and
// converts the two IValues into the RRefId / ForkId stored on this object.
UnpickledPythonRemoteCall::UnpickledPythonRemoteCall(
const SerializedPyObj& serializedPyObj,
const at::IValue& rrefId,
const at::IValue& forkId)
: UnpickledPythonCall(serializedPyObj),
rrefId_(RRefId::fromIValue(rrefId)),
forkId_(ForkId::fromIValue(forkId)) {}
// Returns a reference to the RRefId built from the constructor's IValue; the
// reference is to a member, so it is valid only while this object is alive.
const RRefId& UnpickledPythonRemoteCall::rrefId() const {
return rrefId_;
}
// Returns a reference to the ForkId built from the constructor's IValue; the
// reference is to a member, so it is valid only while this object is alive.
const ForkId& UnpickledPythonRemoteCall::forkId() const {
return forkId_;
}
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,36 @@
#pragma once
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace distributed {
namespace rpc {
// This class converts the content in a PythonRemoteCall into py::object. This
// is a helper class to make sure that all arguments deserialization is done
// before entering RequestCallbackImpl::processRpc(...), so that the
// deserialization related logic can be carried out in one spot instead of
// scattered in multiple places for different message types.
// NB: The reason for not consolidating this class into PythonRemoteCall is
// that PythonRemoteCall is a libtorch type which should not depend on Python
// types.
class TORCH_API UnpickledPythonRemoteCall final : public UnpickledPythonCall {
public:
// Deserializes serializedPyObj through the base class and converts the two
// IValues into an RRefId and a ForkId.
// NOTE(review): the `ret` prefix suggests these identify the RRef that will
// hold the call's return value -- confirm against the caller.
explicit UnpickledPythonRemoteCall(
const SerializedPyObj& serializedPyObj,
const at::IValue& retRRefId,
const at::IValue& retForkId);
// Accessors for the IDs captured at construction time.
const RRefId& rrefId() const;
const ForkId& forkId() const;
private:
// Built once in the constructor from the corresponding IValue arguments.
RRefId rrefId_;
ForkId forkId_;
};
} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@ -63,6 +63,14 @@ RegisterOperators reg_rpc_ops({
return 0;
},
aliasAnalysisFromSchema()),
// Makes confirmed_by_owner() callable on an RRef from TorchScript: pops the
// RRef off the interpreter stack and pushes the bool returned by its
// confirmedByOwner() method.
Operator(
"aten::confirmed_by_owner(RRef(t) self) -> bool",
[](Stack& stack) {
auto rref = pop(stack).toRRef();
push(stack, rref->confirmedByOwner());
return 0;
},
aliasAnalysisFromSchema()),
Operator(
prim::rpc_async,
[](const Node* node) -> Operation {

View File

@ -97,7 +97,7 @@ class LocalRRefTest(RpcAgentTestFixture):
return
# Create a local RRef<MyModuleInterface>.
rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
ret = rref_script_module.to_here().forward()
self.assertEqual(ret, torch.ones(self.rank))
@ -535,6 +535,11 @@ def rref_script_annotation(rref_var):
return rref_python_annotation(rref_var).to_here()
# TorchScript helper for the confirmation tests: reports whether the RRef it
# receives has been confirmed by its owner.
@torch.jit.script
def script_check_rref_confirmed(rref):
    # type: (RRef[Tensor]) -> bool
    is_confirmed = rref.confirmed_by_owner()
    return is_confirmed
@unittest.skipIf(
not torch._six.PY3, "Pytorch distributed rpc package does not support python2"
)
@ -673,3 +678,34 @@ class JitRpcTest(LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixture):
res = rref_script_annotation(rref_var)
self.assertEqual(res, torch.ones(2, 2) + 1)
def _create_rref(self):
    """Create an RRef via rpc.remote on the worker two ranks ahead.

    The owner rank wraps around the world size. The remote computation is
    ``torch.add(torch.zeros(2, 2), 1)``.
    """
    owner_worker = "worker{}".format((self.rank + 2) % self.world_size)
    remote_args = (torch.zeros(2, 2), 1)
    return rpc.remote(owner_worker, torch.add, args=remote_args)
@dist_init
def test_user_rrefs_confirmed(self):
    # An RRef passed as an rpc_sync argument must be reported as confirmed
    # by its owner when the scripted function runs on the next worker.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    user_rref = self._create_rref()
    confirmed = rpc.rpc_sync(
        dst_worker, script_check_rref_confirmed, args=(user_rref,)
    )
    self.assertEqual(confirmed, True)
@dist_init
def test_user_rrefs_confirmed_remote(self):
    # Same check as the rpc_sync variant, but launched through rpc.remote;
    # the answer is fetched from the returned RRef with to_here().
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    user_rref = self._create_rref()
    result_rref = rpc.remote(
        dst_worker, script_check_rref_confirmed, args=(user_rref,)
    )
    confirmed = result_rref.to_here()
    self.assertEqual(confirmed, True)

View File

@ -259,6 +259,9 @@ def clear_global_rref():
global_rref = None
def check_rref_confirmed(rref):
    """Return whatever ``rref.confirmed_by_owner()`` reports for the given RRef."""
    is_confirmed = rref.confirmed_by_owner()
    return is_confirmed
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -1789,3 +1792,33 @@ class RpcTest(RpcAgentTestFixture):
# Sending to self should fail too.
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
rpc.rpc_sync(worker_name(self.rank), torch.add, args=(t1, t2))
def _create_rref(self):
    """Create an RRef via rpc.remote on the worker two ranks ahead.

    The owner rank wraps around the world size. The remote computation is
    ``torch.add(torch.zeros(2, 2), 1)``.
    """
    owner_worker = "worker{}".format((self.rank + 2) % self.world_size)
    remote_args = (torch.zeros(2, 2), 1)
    return rpc.remote(owner_worker, torch.add, args=remote_args)
@dist_init
def test_user_rrefs_confirmed(self):
    # An RRef passed as an rpc_sync argument must be reported as confirmed
    # by its owner when the Python function runs on the next worker.
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    user_rref = self._create_rref()
    confirmed = rpc.rpc_sync(
        dst_worker, check_rref_confirmed, args=(user_rref,)
    )
    self.assertEqual(confirmed, True)
@dist_init
def test_user_rrefs_confirmed_remote(self):
    # Same check as the rpc_sync variant, but launched through rpc.remote;
    # the answer is fetched from the returned RRef with to_here().
    dst_worker = "worker{}".format((self.rank + 1) % self.world_size)
    user_rref = self._create_rref()
    result_rref = rpc.remote(
        dst_worker, check_rref_confirmed, args=(user_rref,)
    )
    confirmed = result_rref.to_here()
    self.assertEqual(confirmed, True)