mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Don't run user function until all UserRRefs in the args are confirmed (#34497)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497 Use a thread_local table to intercept UserRRefs created during user function args deserialization, and then wait for confirmations of those UserRRefs before launching the given user function. Differential Revision: D20347464 Test Plan: Imported from OSS Pulled By: mrshenli fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d876fef743
commit
422e348619
@ -27,7 +27,11 @@ class C10_EXPORT RRefInterface : public c10::intrusive_ptr_target {
|
||||
// Returns true if this is the ``OwnerRRef``
|
||||
virtual bool isOwner() const = 0;
|
||||
|
||||
// Returns true if this is an ``OwnerRRef`` or if this ``UserRRef`` has been
|
||||
// confirmed by its owner.
|
||||
virtual bool confirmedByOwner() const = 0;
|
||||
|
||||
virtual const TypePtr type() const = 0;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -111,6 +111,7 @@ white_list = [
|
||||
('aten::_linear_prepack', datetime.date(2020, 4, 1)),
|
||||
('aten::_conv2d_packed', datetime.date(2020, 4, 1)),
|
||||
('aten::_conv2d_prepack', datetime.date(2020, 4, 1)),
|
||||
('aten::confirmed_by_owner', datetime.date(2020, 3, 17)),
|
||||
]
|
||||
|
||||
|
||||
|
@ -347,6 +347,8 @@ def add_torch_libs():
|
||||
"torch/csrc/distributed/rpc/python_functions.cpp",
|
||||
"torch/csrc/distributed/rpc/python_rpc_handler.cpp",
|
||||
"torch/csrc/distributed/rpc/request_callback_impl.cpp",
|
||||
"torch/csrc/distributed/rpc/unpickled_python_call.cpp",
|
||||
"torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp",
|
||||
"torch/csrc/jit/python/init.cpp",
|
||||
"torch/csrc/jit/passes/inline_fork_wait.cpp",
|
||||
"torch/csrc/jit/passes/onnx.cpp",
|
||||
|
@ -258,6 +258,8 @@ if (USE_DISTRIBUTED)
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/unpickled_python_call.cpp
|
||||
${TORCH_SRC_DIR}/csrc/distributed/rpc/unpickled_python_remote_call.cpp
|
||||
${TORCH_SRC_DIR}/csrc/jit/runtime/register_distributed_ops.cpp
|
||||
)
|
||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
|
||||
|
@ -166,6 +166,11 @@ RpcCommandBase& RpcWithAutograd::wrappedRpc() {
|
||||
return *wrappedRpc_;
|
||||
}
|
||||
|
||||
void RpcWithAutograd::setWrappedRpc(
|
||||
std::unique_ptr<RpcCommandBase> wrappedRpc) {
|
||||
wrappedRpc_ = std::move(wrappedRpc);
|
||||
}
|
||||
|
||||
std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
|
||||
TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
|
||||
return std::move(wrappedRpc_);
|
||||
|
@ -42,6 +42,8 @@ class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
|
||||
|
||||
RpcCommandBase& wrappedRpc();
|
||||
|
||||
void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);
|
||||
|
||||
std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
|
||||
|
||||
// Message type of the wrapped RPC.
|
||||
|
@ -168,6 +168,14 @@ PyObject* rpc_init(PyObject* /* unused */) {
|
||||
Returns whether or not the current node is the owner of this
|
||||
``RRef``.
|
||||
)")
|
||||
.def(
|
||||
"confirmed_by_owner",
|
||||
&PyRRef::confirmedByOwner,
|
||||
R"(
|
||||
Returns whether this ``RRef`` has been confirmed by the owner.
|
||||
``OwnerRRef`` always returns true, while ``UserRRef`` only
|
||||
returns true when the owner knowns about this ``UserRRef``.
|
||||
)")
|
||||
.def(
|
||||
// not releasing GIL here to avoid context switch on getters
|
||||
"owner",
|
||||
|
@ -112,21 +112,19 @@ void Message::setId(int64_t id) {
|
||||
id_ = id;
|
||||
}
|
||||
|
||||
// Builds an EXCEPTION response for the request identified by `requestId`,
// using `e.what()` as the payload. Delegates to the string overload.
Message createExceptionResponse(int64_t requestId, const std::exception& e) {
  return createExceptionResponse(requestId, e.what());
}
|
||||
|
||||
// Builds an EXCEPTION response for the request identified by `requestId`.
// The bytes of `exceptionStr` become the message payload; the response carries
// no tensors.
Message createExceptionResponse(
    int64_t requestId,
    const std::string& exceptionStr) {
  std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
  return Message(
      std::move(payload),
      std::vector<torch::Tensor>(),
      MessageType::EXCEPTION,
      requestId);
}
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -118,13 +118,12 @@ class TORCH_API Message final {
|
||||
// Create a response Message with an exception for the provided request message.
|
||||
// The exception string representation will be used as the message's payload.
|
||||
TORCH_API Message
|
||||
createExceptionResponse(const Message& request, const std::exception& e);
|
||||
createExceptionResponse(int64_t requestId, const std::exception& e);
|
||||
|
||||
// Create a response Message with an exception type for the provided request
|
||||
// message. The passed in string will be used as the created message's payload
|
||||
TORCH_API Message createExceptionResponse(
|
||||
const Message& request,
|
||||
const std::string& exceptionStr);
|
||||
TORCH_API Message
|
||||
createExceptionResponse(int64_t requestId, const std::string& exceptionStr);
|
||||
|
||||
typedef torch::utils::Future<Message> FutureMessage;
|
||||
|
||||
|
@ -416,7 +416,7 @@ void ProcessGroupAgent::enqueueSend(SendWork work) {
|
||||
"Encountered exception in ProcessGroupAgent::enqueueSend: ",
|
||||
e.what());
|
||||
auto exceptionMsg =
|
||||
rpc::createExceptionResponse(work.message_, errorStr);
|
||||
rpc::createExceptionResponse(work.message_.id(), errorStr);
|
||||
if (work.message_.isRequest()) {
|
||||
markFutureWithError(exceptionMsg);
|
||||
} else if (work.message_.isResponse()) {
|
||||
@ -455,7 +455,7 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) {
|
||||
send(
|
||||
work.from_,
|
||||
createExceptionResponse(
|
||||
message, futureResponse->error()->what()));
|
||||
message.id(), futureResponse->error()->what()));
|
||||
}
|
||||
} else {
|
||||
++serverActiveAsyncCalls_;
|
||||
@ -685,8 +685,7 @@ void ProcessGroupAgent::pollTimedOutRPCs() {
|
||||
"RPC ran for more than ",
|
||||
timedOutFuture.timeout_.count(),
|
||||
" milliseconds and timed out.");
|
||||
const auto exceptionMsg = createExceptionResponse(
|
||||
Message({}, {}, MessageType::EXCEPTION), errorStr);
|
||||
const auto exceptionMsg = createExceptionResponse(-1, errorStr);
|
||||
if (!timedOutFuture.future_->hasError()) {
|
||||
--clientActiveCalls_;
|
||||
timedOutFuture.future_->setError(std::string(
|
||||
|
@ -113,6 +113,10 @@ bool PyRRef::isOwner() const {
|
||||
return rref_->isOwner();
|
||||
}
|
||||
|
||||
bool PyRRef::confirmedByOwner() const {
|
||||
return rref_->confirmedByOwner();
|
||||
}
|
||||
|
||||
WorkerInfo PyRRef::owner() const {
|
||||
return RRefContext::getInstance().agent()->getWorkerInfo(rref_->owner());
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ class PyRRef {
|
||||
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
|
||||
|
||||
bool isOwner() const;
|
||||
bool confirmedByOwner() const;
|
||||
WorkerInfo owner() const;
|
||||
py::object toHere();
|
||||
py::object localValue();
|
||||
|
@ -9,45 +9,14 @@ namespace rpc {
|
||||
|
||||
using namespace torch::distributed::autograd;
|
||||
|
||||
namespace {
|
||||
|
||||
// When request message has autograd info, processMessage() will set up valid
|
||||
// current context id properly. This struct is used to clean up current context
|
||||
// id after processMessage() is done.
|
||||
struct ClearAutogradContextGuard {
|
||||
ClearAutogradContextGuard() = default;
|
||||
~ClearAutogradContextGuard() {
|
||||
clear();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
auto& autogradContainer = DistAutogradContainer::getInstance();
|
||||
autogradContainer.clearCurrentContext();
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::shared_ptr<FutureMessage> RequestCallback::operator()(
|
||||
Message& request) const {
|
||||
// For a recv thread, current context id should be invalid outside
|
||||
// processMessage().
|
||||
ClearAutogradContextGuard guard;
|
||||
try {
|
||||
return processMessage(request);
|
||||
} catch (std::exception& e) {
|
||||
LOG(ERROR) << "Received error while processing request type "
|
||||
<< request.type() << ": " << e.what();
|
||||
// Adding node information to the error here since all processed RPC
|
||||
// requests should be going through this function.
|
||||
std::string errorMsg = c10::str(
|
||||
"Error on Node ",
|
||||
DistAutogradContainer::getInstance().getWorkerId(),
|
||||
": ",
|
||||
e.what());
|
||||
return std::make_shared<FutureMessage>(
|
||||
createExceptionResponse(request, errorMsg));
|
||||
}
|
||||
// NB: cannot clear autograd context id here because the processMessage method
|
||||
// might pause waiting for all RRefs in the arguments to be confirmed by their
|
||||
// owners and resume processing in a different thread. Hence, the
|
||||
// thread_local context id needs to be set and cleared in the thread that
|
||||
// indeed carries out the processing logic.
|
||||
return processMessage(request);
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -20,6 +20,8 @@
|
||||
#include <torch/csrc/distributed/rpc/script_call.h>
|
||||
#include <torch/csrc/distributed/rpc/script_remote_call.h>
|
||||
#include <torch/csrc/distributed/rpc/script_resp.h>
|
||||
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
||||
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
|
||||
#include <torch/csrc/distributed/rpc/utils.h>
|
||||
#include <torch/csrc/jit/python/pybind_utils.h>
|
||||
|
||||
@ -29,13 +31,86 @@ namespace rpc {
|
||||
|
||||
using namespace torch::distributed::autograd;
|
||||
|
||||
std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
namespace {
|
||||
|
||||
// Builds the "unpickled" counterpart of an RPC command that carries a
// serialized Python object.
//
// - PYTHON_CALL: returns a new UnpickledPythonCall built from the call's
//   serialized Python object.
// - PYTHON_REMOTE_CALL: returns a new UnpickledPythonRemoteCall built from the
//   serialized Python object plus the return RRef id and fork id.
// - FORWARD_AUTOGRAD_REQ: applies the same conversion to the wrapped RPC and,
//   if a replacement was produced, installs it into the RpcWithAutograd in
//   place. Returns nullptr because `rpc` itself is not replaced.
// - Anything else: returns nullptr, meaning "keep using `rpc` as is".
//
// NOTE(review): `rpc` is downcast with static_cast, so `messageType` must
// truthfully describe the dynamic type of `rpc`.
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
    RpcCommandBase& rpc,
    const MessageType& messageType) {
  switch (messageType) {
    case MessageType::PYTHON_CALL: {
      auto& pc = static_cast<PythonCall&>(rpc);
      return std::make_unique<UnpickledPythonCall>(pc.serializedPyObj());
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      auto& prc = static_cast<PythonRemoteCall&>(rpc);
      return std::make_unique<UnpickledPythonRemoteCall>(
          prc.serializedPyObj(), prc.retRRefId(), prc.retForkId());
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      // Deserialize the wrapped RPC if it contains Python UDF
      auto& rwa = static_cast<RpcWithAutograd&>(rpc);
      auto& wrappedRpc = rwa.wrappedRpc();
      auto pythonRpc = deserializePythonRpcCommandReference(
          wrappedRpc, rwa.wrappedMessageType());
      if (pythonRpc) {
        rwa.setWrappedRpc(std::move(pythonRpc));
      }
      return nullptr;
    }
    default: {
      return nullptr;
    }
  }
}
|
||||
|
||||
std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
|
||||
std::unique_ptr<RpcCommandBase> rpc,
|
||||
const MessageType& messageType) {
|
||||
auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType);
|
||||
return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
|
||||
}
|
||||
|
||||
// When request message has autograd info, processMessage() will set up valid
|
||||
// current context id properly. This struct is used to clean up current context
|
||||
// id after processMessage() is done.
|
||||
struct ClearAutogradContextGuard {
|
||||
ClearAutogradContextGuard() = default;
|
||||
~ClearAutogradContextGuard() {
|
||||
clear();
|
||||
}
|
||||
|
||||
void clear() {
|
||||
auto& autogradContainer = DistAutogradContainer::getInstance();
|
||||
autogradContainer.clearCurrentContext();
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Logs an exception raised while processing a request of type `messageType`
// and converts it into an EXCEPTION response for request `messageId`. The
// error text is prefixed with the local worker id.
Message RequestCallbackImpl::handleError(
    const std::exception& e,
    const MessageType messageType,
    int64_t messageId) const {
  LOG(ERROR) << "Received error while processing request type " << messageType
             << ": " << e.what();
  // Adding node information to the error here since all processed RPC
  // requests should be going through this function.
  std::string errorMsg = c10::str(
      "Error on Node ",
      DistAutogradContainer::getInstance().getWorkerId(),
      ": ",
      e.what());
  return createExceptionResponse(messageId, errorMsg);
}
|
||||
|
||||
void RequestCallbackImpl::processRpc(
|
||||
RpcCommandBase& rpc,
|
||||
const MessageType& messageType,
|
||||
const int64_t messageId,
|
||||
const std::shared_ptr<FutureMessage>& responseFuture) const {
|
||||
auto markComplete = [messageId, &responseFuture](Message m) {
|
||||
m.setId(messageId);
|
||||
return std::make_shared<FutureMessage>(std::move(m));
|
||||
responseFuture->markCompleted(std::move(m));
|
||||
};
|
||||
|
||||
// TODO: RpcCommandBase should have an abstract execute() method that we can
|
||||
@ -65,22 +140,22 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
"TorchScript function should be a single IValue, got a vector of "
|
||||
"size ",
|
||||
stack.size());
|
||||
|
||||
return wrap(std::move(ScriptResp(std::move(stack.front()))).toMessage());
|
||||
markComplete(std::move(ScriptResp(std::move(stack.front()))).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::PYTHON_CALL: {
|
||||
auto& pyCall = static_cast<PythonCall&>(rpc);
|
||||
auto& upc = static_cast<UnpickledPythonCall&>(rpc);
|
||||
auto& pythonRpcHandler = PythonRpcHandler::getInstance();
|
||||
std::shared_ptr<SerializedPyObj> serializedPyObj = nullptr;
|
||||
{
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
auto pythonUdf = pythonRpcHandler.deserialize(pyCall.serializedPyObj());
|
||||
serializedPyObj =
|
||||
std::make_shared<SerializedPyObj>(pythonRpcHandler.serialize(
|
||||
pythonRpcHandler.runPythonUdf(std::move(pythonUdf))));
|
||||
pythonRpcHandler.runPythonUdf(std::move(upc).movePythonUdf())));
|
||||
}
|
||||
return wrap(
|
||||
markComplete(
|
||||
std::move(PythonResp(std::move(*serializedPyObj))).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::SCRIPT_REMOTE_CALL: {
|
||||
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
|
||||
@ -135,13 +210,14 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
// rrefId (OwnerRRef does not have a forkId anyway).
|
||||
ctx.addForkOfOwner(rrefId, forkId);
|
||||
}
|
||||
return wrap(RemoteRet(rrefId, forkId).toMessage());
|
||||
markComplete(RemoteRet(rrefId, forkId).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::PYTHON_REMOTE_CALL: {
|
||||
auto& prc = static_cast<PythonRemoteCall&>(rpc);
|
||||
auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
|
||||
|
||||
auto rrefId = RRefId::fromIValue(prc.retRRefId());
|
||||
auto forkId = ForkId::fromIValue(prc.retForkId());
|
||||
const auto& rrefId = uprc.rrefId();
|
||||
const auto& forkId = uprc.forkId();
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
|
||||
auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get());
|
||||
@ -150,9 +226,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
IValue py_ivalue;
|
||||
{
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
auto pythonUdf = pythonRpcHandler.deserialize(prc.serializedPyObj());
|
||||
py_ivalue = jit::toIValue(
|
||||
pythonRpcHandler.runPythonUdf(std::move(pythonUdf)),
|
||||
pythonRpcHandler.runPythonUdf(std::move(uprc).movePythonUdf()),
|
||||
PyObjectType::get());
|
||||
}
|
||||
|
||||
@ -169,32 +244,33 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
// rrefId (OwnerRRef does not have a forkId anyway).
|
||||
ctx.addForkOfOwner(rrefId, forkId);
|
||||
}
|
||||
return wrap(RemoteRet(rrefId, forkId).toMessage());
|
||||
markComplete(RemoteRet(rrefId, forkId).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::SCRIPT_RREF_FETCH_CALL: {
|
||||
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
c10::intrusive_ptr<OwnerRRef> rref = ctx.getOwnerRRef(srf.rrefId());
|
||||
if (rref->hasValue()) { // optional fast-path
|
||||
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
|
||||
markComplete(ScriptRRefFetchRet({rref->getValue()}).toMessage());
|
||||
return;
|
||||
} else {
|
||||
auto whenValueSet = rref->getFuture();
|
||||
// Our response is satisfied when the rpcs come back.
|
||||
whenValueSet->addCallback(
|
||||
[responseFuture, messageId, rref](
|
||||
const rpc::Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& error) {
|
||||
if (!error) {
|
||||
Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage();
|
||||
m.setId(messageId);
|
||||
responseFuture->markCompleted(std::move(m));
|
||||
} else {
|
||||
responseFuture->setError(error->what());
|
||||
}
|
||||
});
|
||||
}
|
||||
auto whenValueSet = rref->getFuture();
|
||||
auto responseFuture = std::make_shared<FutureMessage>();
|
||||
|
||||
// Our response is satisfied when the rpcs come back.
|
||||
whenValueSet->addCallback(
|
||||
[responseFuture, messageId, rref](
|
||||
const rpc::Message& /* unused */,
|
||||
const c10::optional<utils::FutureError>& error) {
|
||||
if (!error) {
|
||||
Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage();
|
||||
m.setId(messageId);
|
||||
responseFuture->markCompleted(std::move(m));
|
||||
} else {
|
||||
responseFuture->setError(error->what());
|
||||
}
|
||||
});
|
||||
return responseFuture;
|
||||
return;
|
||||
}
|
||||
case MessageType::PYTHON_RREF_FETCH_CALL: {
|
||||
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
|
||||
@ -209,12 +285,12 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
}
|
||||
SerializedPyObj result =
|
||||
PythonRpcHandler::getInstance().serialize(pyValue);
|
||||
return wrap(
|
||||
markComplete(
|
||||
PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
|
||||
return;
|
||||
}
|
||||
|
||||
auto whenValueSet = rref->getFuture();
|
||||
auto responseFuture = std::make_shared<FutureMessage>();
|
||||
|
||||
// Our response is satisfied when the rpcs come back.
|
||||
whenValueSet->addCallback(
|
||||
@ -238,7 +314,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
responseFuture->setError(error->what());
|
||||
}
|
||||
});
|
||||
return responseFuture;
|
||||
return;
|
||||
}
|
||||
case MessageType::RREF_USER_DELETE: {
|
||||
auto& rud = static_cast<RRefUserDelete&>(rpc);
|
||||
@ -248,19 +324,22 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
pybind11::gil_scoped_acquire ag;
|
||||
deletedRRef.reset();
|
||||
}
|
||||
return wrap(std::move(RRefAck()).toMessage());
|
||||
markComplete(std::move(RRefAck()).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::RREF_CHILD_ACCEPT: {
|
||||
auto& rca = static_cast<RRefChildAccept&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
ctx.delPendingChild(rca.forkId());
|
||||
return wrap(std::move(RRefAck()).toMessage());
|
||||
markComplete(std::move(RRefAck()).toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::RREF_FORK_REQUEST: {
|
||||
auto& rfr = static_cast<RRefForkRequest&>(rpc);
|
||||
auto& ctx = RRefContext::getInstance();
|
||||
ctx.addForkOfOwner(rfr.rrefId(), rfr.forkId());
|
||||
return wrap(RRefAck().toMessage());
|
||||
markComplete(RRefAck().toMessage());
|
||||
return;
|
||||
}
|
||||
case MessageType::FORWARD_AUTOGRAD_REQ: {
|
||||
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
|
||||
@ -283,13 +362,16 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
|
||||
// Process the original RPC.
|
||||
auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
|
||||
// Make an overall future for the wrapped response.
|
||||
auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
|
||||
// Kick off processing for the nested future and get a Future<T> to the
|
||||
// result.
|
||||
auto wrappedRpcResponseFuture = processRpc(
|
||||
rpcWithAutograd.wrappedRpc(), wrappedMessageType, messageId);
|
||||
processRpc(
|
||||
rpcWithAutograd.wrappedRpc(),
|
||||
wrappedMessageType,
|
||||
messageId,
|
||||
wrappedRpcResponseFuture);
|
||||
|
||||
// Make an overall future for the wrapped response.
|
||||
auto responseFuture = std::make_shared<rpc::FutureMessage>();
|
||||
auto fromWorkerId = rpcWithAutograd.fromWorkerId();
|
||||
// The original future needs to be marked as completed when the wrapped
|
||||
// one completes, with the autograd context information wrapped.
|
||||
@ -309,7 +391,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
responseFuture->markCompleted(std::move(msg));
|
||||
}
|
||||
});
|
||||
return responseFuture;
|
||||
return;
|
||||
}
|
||||
case MessageType::BACKWARD_AUTOGRAD_REQ: {
|
||||
auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
|
||||
@ -328,8 +410,6 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
// Attach the gradients to the send function.
|
||||
sendFunction->setGrads(gradientsCall.getGrads());
|
||||
|
||||
auto responseFuture = std::make_shared<rpc::FutureMessage>();
|
||||
|
||||
// Now execute the autograd graph using the "distributed engine."
|
||||
auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
|
||||
autogradContext, sendFunction, gradientsCall.retainGraph());
|
||||
@ -347,7 +427,7 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
responseFuture->setError(error->what());
|
||||
}
|
||||
});
|
||||
return responseFuture;
|
||||
return;
|
||||
};
|
||||
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
|
||||
auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
|
||||
@ -358,7 +438,8 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
// notified to clean up their context.
|
||||
DistAutogradContainer::getInstance().releaseContextIfPresent(
|
||||
cleanupContextId);
|
||||
return wrap(std::move(CleanupAutogradContextResp()).toMessage());
|
||||
markComplete(std::move(CleanupAutogradContextResp()).toMessage());
|
||||
return;
|
||||
}
|
||||
default: {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
@ -369,8 +450,42 @@ std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
|
||||
|
||||
// Entry point for handling one incoming request message.
//
// Returns a future that is completed with the response. Errors thrown while
// deserializing, or synchronously from processRpc(), are turned into an
// EXCEPTION response via handleError().
std::shared_ptr<FutureMessage> RequestCallbackImpl::processMessage(
    Message& request) const {
  // We need two futures here because it could pause twice when processing a
  // RPC message:
  //  1) waiting for all RRefs in the arguments to become confirmed;
  //  2) waiting for processRpc to finish.
  auto retFuture = std::make_shared<FutureMessage>();
  auto& rrefContext = RRefContext::getInstance();
  try {
    // Record every pending UserRRef created while deserializing the request,
    // so that processing can be deferred until all of them are confirmed.
    rrefContext.recordThreadLocalPendingRRefs();
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    rrefsReadyFuture->addCallback(
        [this,
         retFuture,
         // std::function must be copyable, hence we have to convert the
         // unique_ptr to a shared_ptr here.
         rpc = std::shared_ptr<RpcCommandBase>(std::move(rpc)),
         messageType = request.type(),
         id = request.id()](
            const bool& /*unused*/,
            const c10::optional<utils::FutureError>& /*unused*/) {
          try {
            // For a recv thread, current context id should be invalid outside
            // processMessage(). The guard lives in the callback because this
            // is the thread that actually runs the processing logic.
            ClearAutogradContextGuard guard;
            processRpc(*rpc, messageType, id, retFuture);
          } catch (const std::exception& e) {
            retFuture->markCompleted(handleError(e, messageType, id));
          }
        });
  } catch (const std::exception& e) {
    retFuture->markCompleted(handleError(e, request.type(), request.id()));
    // The recording session may not have been closed by
    // waitForThreadLocalPendingRRefs(); reset the thread-local state.
    rrefContext.clearRecordedPendingRRefsOnError();
  }
  return retFuture;
}
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -14,9 +14,15 @@ class TORCH_API RequestCallbackImpl : public RequestCallback {
|
||||
Message& request) const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<FutureMessage> processRpc(
|
||||
void processRpc(
|
||||
RpcCommandBase& rpc,
|
||||
MessageType messageType,
|
||||
const MessageType& messageType,
|
||||
const int64_t messageId,
|
||||
const std::shared_ptr<FutureMessage>& retFutureMessagge) const;
|
||||
|
||||
Message handleError(
|
||||
const std::exception& e,
|
||||
const MessageType messageType,
|
||||
int64_t messageId) const;
|
||||
};
|
||||
|
||||
|
@ -7,6 +7,10 @@ namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
|
||||
RRefContext::userTable_;
|
||||
thread_local bool RRefContext::recording = false;
|
||||
|
||||
namespace callback {
|
||||
void confirmPendingUser(
|
||||
const rpc::Message& message,
|
||||
@ -438,38 +442,104 @@ void RRefContext::addPendingUser(
|
||||
const c10::intrusive_ptr<RRef>& rref) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");
|
||||
|
||||
auto state = std::make_shared<PendingUserState>(rref);
|
||||
if (recording) {
|
||||
// adding and waiting for pending users are guaranteed to be called from the
|
||||
// same thread, but deleting pending users will be called from another
|
||||
// thread. As the delPendingUser will not be able to access the same
|
||||
// thread_local variable, we cannot address this problem by making
|
||||
// pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
|
||||
// the same PendingUserState shared_ptr.
|
||||
userTable_.push_back(state);
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
pendingUsers_.find(forkId) == pendingUsers_.end(),
|
||||
"Inconsistent states: attempt to add the same UserRRef twice.");
|
||||
pendingUsers_[forkId] = rref;
|
||||
|
||||
pendingUsers_.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(forkId),
|
||||
std::forward_as_tuple(state));
|
||||
}
|
||||
|
||||
void RRefContext::delPendingUser(const ForkId& forkId) {
|
||||
c10::intrusive_ptr<RRef> deletedUser;
|
||||
std::shared_ptr<PendingUserState> deletedState = nullptr;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto iter = pendingUsers_.find(forkId);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
iter != pendingUsers_.end(),
|
||||
"Inconsistent states: attempt to delete a non-exist UserRRef.");
|
||||
|
||||
// There are two reasons for keeping the deleted PendingUserState alive
|
||||
// until exiting the critical section.
|
||||
// (1) Since this UserRRef is removed from the map, the refcount of this
|
||||
// UserRRef could reach to 0. So the resource destructor
|
||||
// (`release_resources()`) might be called, in which the lock is
|
||||
// acquired again. Hence, it must be destructed with the lock released.
|
||||
// To meet this constraint, we intentionally create a temporary pointer
|
||||
// to increase the refcount of the deleted PendingUserState, extending
|
||||
// its lifetime untill lock released.
|
||||
// (2) Since #34497, a user function only runs after all RRefs in the
|
||||
// arguments are confirmed by their owners, which is done by adding the
|
||||
// RPC processing logic as a callback to the UserRRef ready future. So,
|
||||
// calling `confirm` on the PendingUserState could trigger pending user
|
||||
// functions, which might in turn acquire the lock in RRefContext.
|
||||
// Hence, we must release the lock to prevent deadlock.
|
||||
// NB: Another option is to use reentrant lock. However, it is better for
|
||||
// the developers to fully understand the locking behavior instead of
|
||||
// hiding the subtle logic using a reentrant lock.
|
||||
deletedState = iter->second; // Increase refcount
|
||||
|
||||
confirmedUsers_.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(forkId),
|
||||
std::forward_as_tuple(iter->second));
|
||||
|
||||
// Since this UserRRef is removed from the map,
|
||||
// the refcount of this UserRRef could reach to 0,
|
||||
// so the "destructor", `release_resources()`, might be called,
|
||||
// in which the lock is acquired again.
|
||||
// So it must be destructed with the lock released.
|
||||
// Meet this constraint by creating a temporary pointer to increase the
|
||||
// refcount, extending its lifetime untill lock released.
|
||||
deletedUser = iter->second; // Increase refcount.
|
||||
std::forward_as_tuple(iter->second->rref_));
|
||||
pendingUsers_.erase(iter); // Decrease refcount.
|
||||
}
|
||||
deletedState->confirm();
|
||||
deleteAllUsersCV_.notify_all();
|
||||
deletedUser.reset(); // Decrease refcount.
|
||||
deletedState.reset(); // Decrease refcount.
|
||||
}
|
||||
|
||||
// Opens a recording session on the calling thread: while `recording` is set,
// addPendingUser() also appends each new pending state to the thread_local
// userTable_. The table must be empty when a session starts.
void RRefContext::recordThreadLocalPendingRRefs() {
  TORCH_INTERNAL_ASSERT(
      userTable_.empty(),
      "User RRef Table should be empty when start recording");

  recording = true;
}
|
||||
|
||||
std::shared_ptr<torch::utils::Future<bool>> RRefContext::
|
||||
waitForThreadLocalPendingRRefs() {
|
||||
auto future = std::make_shared<torch::utils::Future<bool>>();
|
||||
if (userTable_.empty()) {
|
||||
future->markCompleted(true);
|
||||
} else {
|
||||
auto remainingRRefs =
|
||||
std::make_shared<std::atomic<uint64_t>>(userTable_.size());
|
||||
for (auto& state : userTable_) {
|
||||
state->future_.addCallback(
|
||||
[future, remainingRRefs](
|
||||
const bool& /* unused */,
|
||||
const c10::optional<utils::FutureError>& /* unused */) {
|
||||
auto localCount = remainingRRefs->fetch_sub(1);
|
||||
if (localCount == 1) {
|
||||
future->markCompleted(true);
|
||||
}
|
||||
});
|
||||
}
|
||||
userTable_.clear();
|
||||
}
|
||||
recording = false;
|
||||
return future;
|
||||
}
|
||||
|
||||
// Abandons the current recording session on the calling thread: drops every
// recorded entry from userTable_ and turns recording off.
void RRefContext::clearRecordedPendingRRefsOnError() {
  userTable_.clear();

  recording = false;
}
|
||||
|
||||
void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||
#include <torch/csrc/distributed/rpc/rref_impl.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/utils/future.h>
|
||||
|
||||
#include <atomic>
|
||||
|
||||
@ -147,6 +148,27 @@ class TORCH_API RRefContext {
|
||||
const c10::intrusive_ptr<RRef>& rref);
|
||||
void delPendingUser(const ForkId& forkId);
|
||||
|
||||
// Start recording new pending UserRRefs. All pending UserRRefs introduced
|
||||
// after this point will be put into the thread_local userTable_, which will
|
||||
// then be consumed and cleared in waitForThreadLocalPendingRRefs().
|
||||
void recordThreadLocalPendingRRefs();
|
||||
// End recording new pending UserRRefs, and clear the thread_local userTable_.
|
||||
// Returns a Future which will be marked as completed when all pending
|
||||
// UserRRefs in the current userTable_ are confirmed by their owners. The bool
|
||||
// value in the Future is unused.
|
||||
// This method is useful to make sure RRefs in user function arguments are
|
||||
// confirmed before launching user code.
|
||||
// NB: Callers of this method do not need to keep the returned Future alive,
|
||||
// because this Future is already captured in callbacks of the
|
||||
// PendingUserState. If there are no pending UserRRefs, this method returns a
|
||||
// completed future.
|
||||
std::shared_ptr<torch::utils::Future<bool>> waitForThreadLocalPendingRRefs();
|
||||
// Only call this function when there are errors during a recording session,
|
||||
// and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked
|
||||
// properly.
|
||||
// TODO: make this a context guard
|
||||
void clearRecordedPendingRRefsOnError();
|
||||
|
||||
void delUser(
|
||||
const worker_id_t owner,
|
||||
const RRefId& rrefId,
|
||||
@ -156,6 +178,20 @@ class TORCH_API RRefContext {
|
||||
std::unordered_map<std::string, std::string> getDebugInfo();
|
||||
|
||||
private:
|
||||
struct PendingUserState {
|
||||
PendingUserState(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {}
|
||||
|
||||
inline void confirm() {
|
||||
c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm();
|
||||
future_.markCompleted(true);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<RRef> rref_;
|
||||
// Use Future.wait() and Future.markCompleted() to block and unblock user
|
||||
// functions. The bool value wrapped by the future_ is not used.
|
||||
torch::utils::Future<bool> future_;
|
||||
};
|
||||
|
||||
RRefContext(std::shared_ptr<RpcAgent>);
|
||||
|
||||
c10::intrusive_ptr<UserRRef> createUserRRef(
|
||||
@ -209,7 +245,7 @@ class TORCH_API RRefContext {
|
||||
// It can be used or shared, but cannot be deleted, and hence kept alive
|
||||
// in this map. A message of type RREF_USER_ACCEPT will move the
|
||||
// corresponding RRef from pendingUsers_ map to confirmedUsers_ map.
|
||||
std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
|
||||
std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash>
|
||||
pendingUsers_;
|
||||
// UserRRefs are added into this map when it is confirmed by the owner.
|
||||
// When destroying RRefContext this map helps to find local UserRRefs
|
||||
@ -229,6 +265,34 @@ class TORCH_API RRefContext {
|
||||
|
||||
std::mutex destroyedMutex_;
|
||||
bool destroyed_;
|
||||
|
||||
// Thread local states to keep UserRRefs deserialized from user function
|
||||
// arguments.
|
||||
static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_;
|
||||
// A flag indicating whether subsequently created UserRRefs should be added to
|
||||
// the thread_local userTable_. The flag is set to true before deserializing
|
||||
// RPC arguments and then set to false before running the corresponding
|
||||
// user code. See addPendingUser and delPendingUser for more details.
|
||||
// NB: The reason for having this flag is that addPendingUser is called in
|
||||
// two cases, and we only want to track the 2nd case.
|
||||
// (1) RRef as the return value: when calling rpc.remote, the UserRRef on the
|
||||
// caller side is added to the context using addPendingUser.
|
||||
// (2) RRef as an argument: When running an RPC using RRefs as arguments, the
|
||||
// RRef is forwarded to the callee as new UserRRefs (if the callee is not
|
||||
// the owner). In this case, we block running the user function until all
|
||||
// UserRRefs are confirmed by the owner.
|
||||
// This contract guarantees that no UserRRefs can be used remotely without
|
||||
// confirmation. Note that, however, the UserRRef created by rpc.remote can
|
||||
// still be passed to local functions as arguments and used there. This is by
|
||||
// design, because this feature is especially useful when, say a master node
|
||||
// creates multiple UserRRefs in a loop and then shares them with other nodes.
|
||||
// Blocking every iteration in the loop until RRefs are confirmed will slow
|
||||
// this down. This nuance on UserRRef can be interpreted as we only make
|
||||
// exceptions for UserRRef creators. And using the UserRRef on its creator
|
||||
// without confirmation is OK, because the creator would either call to_here
|
||||
// or forward the UserRRef, and both would then require confirmations from the
|
||||
// owner.
|
||||
static thread_local bool recording;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
@ -70,7 +70,9 @@ UserRRef::UserRRef(
|
||||
const RRefId& rrefId,
|
||||
const ForkId& forkId,
|
||||
TypePtr type)
|
||||
: RRef(ownerId, rrefId, std::move(type)), forkId_(forkId) {
|
||||
: RRef(ownerId, rrefId, std::move(type)),
|
||||
forkId_(forkId),
|
||||
confirmedByOwner_(false) {
|
||||
// Do nothing,
|
||||
// (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
|
||||
// a RREF_FORK_REQUEST message to the owner.
|
||||
|
@ -257,6 +257,10 @@ class TORCH_API UserRRef final : public RRef {
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool confirmedByOwner() const override {
|
||||
return confirmedByOwner_;
|
||||
}
|
||||
|
||||
// Returns the globally unique ForkId of this RRef
|
||||
const ForkId& forkId() const;
|
||||
|
||||
@ -280,6 +284,9 @@ class TORCH_API UserRRef final : public RRef {
|
||||
friend class RRefContext;
|
||||
|
||||
RRefForkData fork() const override;
|
||||
// Records that the owner has confirmed this UserRRef.
void confirm() {
  confirmedByOwner_.store(true);
}
|
||||
|
||||
const ForkId forkId_;
|
||||
|
||||
@ -289,6 +296,8 @@ class TORCH_API UserRRef final : public RRef {
|
||||
// proactive cleanup on RPC graceful shutdown.
|
||||
std::mutex deletedOnOwnerMutex_;
|
||||
bool deletedOnOwner_{false};
|
||||
// Indicating whether this UserRRef has been confirmed by its owner.
|
||||
std::atomic<bool> confirmedByOwner_;
|
||||
};
|
||||
|
||||
// Keep the template only on the derived class because ``RRefContext`` needs to
|
||||
@ -316,6 +325,12 @@ class TORCH_API OwnerRRef final : public RRef {
|
||||
return true;
|
||||
}
|
||||
|
||||
// OwnerRRef is always confirmed, while UserRRef is only confirmed when the
|
||||
// owner knows about it.
|
||||
inline bool confirmedByOwner() const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get a constant reference of the real value. This method will block if the
|
||||
// value is not ready. This method does not need GIL as it does not create
|
||||
// any new py::object.
|
||||
|
28
torch/csrc/distributed/rpc/unpickled_python_call.cpp
Normal file
28
torch/csrc/distributed/rpc/unpickled_python_call.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
||||
|
||||
#include <c10/util/C++17.h>
|
||||
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
// Deserializes the given SerializedPyObj into pythonUdf_ at construction time.
UnpickledPythonCall::UnpickledPythonCall(
    const SerializedPyObj& serializedPyObj) {
  // Look up the handler before taking the GIL, then hold the GIL for the
  // deserialization that produces the py::object.
  auto& handler = PythonRpcHandler::getInstance();
  pybind11::gil_scoped_acquire gilGuard;
  pythonUdf_ = handler.deserialize(serializedPyObj);
}
|
||||
|
||||
// Intentionally unsupported: this always trips an internal assert, so no
// Message is ever produced from an UnpickledPythonCall.
Message UnpickledPythonCall::toMessage() && {
  TORCH_INTERNAL_ASSERT(
      false, "UnpickledPythonCall does not support toMessage().");
}
|
||||
|
||||
// Hands the deserialized Python object to the caller by moving it out of
// this (rvalue) command object.
py::object UnpickledPythonCall::movePythonUdf() && {
  return std::move(pythonUdf_);
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
33
torch/csrc/distributed/rpc/unpickled_python_call.h
Normal file
33
torch/csrc/distributed/rpc/unpickled_python_call.h
Normal file
@ -0,0 +1,33 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
// This class converts the content in a PythonCall into py::object. This is a
|
||||
// helper class to make sure that all arguments deserialization is done before
|
||||
// entering RequestCallbackImpl::processRpc(...), so that the deserialization
|
||||
// related logic can be carried out in one spot instead of scattered in multiple
|
||||
// places for different message types.
|
||||
// NB: The reason for not consolidating this class into PythonCall is that
|
||||
// PythonCall is a libtorch type which should not depend on Python types.
|
||||
class TORCH_API UnpickledPythonCall : public RpcCommandBase {
 public:
  // Builds the command from an already-received SerializedPyObj.
  explicit UnpickledPythonCall(const SerializedPyObj& serializedPyObj);

  // Not implemented on purpose: objects of this class must never be turned
  // directly into a Message.
  Message toMessage() && override;

  // Moves the held Python object out; only callable on an rvalue.
  py::object movePythonUdf() &&;

 private:
  // The Python object obtained from the serialized payload.
  py::object pythonUdf_;
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
28
torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp
Normal file
28
torch/csrc/distributed/rpc/unpickled_python_remote_call.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
|
||||
|
||||
#include <c10/util/C++17.h>
|
||||
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
// The base class deserializes the Python payload; the two IValues are
// converted into the RRefId and ForkId stored on this object.
UnpickledPythonRemoteCall::UnpickledPythonRemoteCall(
    const SerializedPyObj& serializedPyObj,
    const at::IValue& rrefId,
    const at::IValue& forkId)
    : UnpickledPythonCall(serializedPyObj),
      rrefId_(RRefId::fromIValue(rrefId)),
      forkId_(ForkId::fromIValue(forkId)) {}
|
||||
|
||||
// Read-only access to the RRefId captured at construction.
const RRefId& UnpickledPythonRemoteCall::rrefId() const {
  return rrefId_;
}
|
||||
|
||||
// Read-only access to the ForkId captured at construction.
const ForkId& UnpickledPythonRemoteCall::forkId() const {
  return forkId_;
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
36
torch/csrc/distributed/rpc/unpickled_python_remote_call.h
Normal file
36
torch/csrc/distributed/rpc/unpickled_python_remote_call.h
Normal file
@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
// This class converts the content in a PythonRemoteCall into py::object. This
|
||||
// is a helper class to make sure that all arguments deserialization is done
|
||||
// before entering RequestCallbackImpl::processRpc(...), so that the
|
||||
// deserialization related logic can be carried out in one spot instead of
|
||||
// scattered in multiple places for different message types.
|
||||
// NB: The reason for not consolidating this class into PythonRemoteCall is that
|
||||
// PythonRemoteCall is a libtorch type which should not depend on Python types.
|
||||
class TORCH_API UnpickledPythonRemoteCall final : public UnpickledPythonCall {
 public:
  // Takes the serialized Python payload plus the RRefId / ForkId, both passed
  // as IValues.
  explicit UnpickledPythonRemoteCall(
      const SerializedPyObj& serializedPyObj,
      const at::IValue& retRRefId,
      const at::IValue& retForkId);

  // Accessors for the ids decoded by the constructor.
  const RRefId& rrefId() const;
  const ForkId& forkId() const;

 private:
  RRefId rrefId_;
  ForkId forkId_;
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
@ -63,6 +63,14 @@ RegisterOperators reg_rpc_ops({
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
"aten::confirmed_by_owner(RRef(t) self) -> bool",
|
||||
[](Stack& stack) {
|
||||
auto rref = pop(stack).toRRef();
|
||||
push(stack, rref->confirmedByOwner());
|
||||
return 0;
|
||||
},
|
||||
aliasAnalysisFromSchema()),
|
||||
Operator(
|
||||
prim::rpc_async,
|
||||
[](const Node* node) -> Operation {
|
||||
|
@ -97,7 +97,7 @@ class LocalRRefTest(RpcAgentTestFixture):
|
||||
return
|
||||
|
||||
# Create a local RRef<MyModuleInterface>.
|
||||
rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
|
||||
rref_script_module = rpc.RRef(MyScriptModule(self.rank), MyModuleInterface)
|
||||
ret = rref_script_module.to_here().forward()
|
||||
self.assertEqual(ret, torch.ones(self.rank))
|
||||
|
||||
@ -535,6 +535,11 @@ def rref_script_annotation(rref_var):
|
||||
return rref_python_annotation(rref_var).to_here()
|
||||
|
||||
|
||||
# TorchScript helper: reports whether the given RRef has been confirmed by
# its owner.
@torch.jit.script
def script_check_rref_confirmed(rref):
    # type: (RRef[Tensor]) -> bool
    confirmed = rref.confirmed_by_owner()
    return confirmed
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch._six.PY3, "Pytorch distributed rpc package does not support python2"
|
||||
)
|
||||
@ -673,3 +678,34 @@ class JitRpcTest(LocalRRefTest, JitRpcAsyncOpTest, RpcAgentTestFixture):
|
||||
|
||||
res = rref_script_annotation(rref_var)
|
||||
self.assertEqual(res, torch.ones(2, 2) + 1)
|
||||
|
||||
|
||||
def _create_rref(self):
    # Remotely run torch.add(zeros(2, 2), 1) on worker (rank + 2) % world_size
    # and hand back the RRef returned by rpc.remote.
    owner_name = "worker{}".format((self.rank + 2) % self.world_size)
    return rpc.remote(owner_name, torch.add, args=(torch.zeros(2, 2), 1))
|
||||
|
||||
@dist_init
def test_user_rrefs_confirmed(self):
    # Pass a freshly created RRef as an argument of a synchronous RPC to the
    # next worker; the scripted callee must see it as confirmed.
    callee_rank = (self.rank + 1) % self.world_size
    rref = self._create_rref()
    callee_name = "worker{}".format(callee_rank)
    ret = rpc.rpc_sync(callee_name, script_check_rref_confirmed, args=(rref,))
    self.assertEqual(ret, True)
|
||||
|
||||
@dist_init
def test_user_rrefs_confirmed_remote(self):
    # Same check as the rpc_sync variant, but issued through rpc.remote and
    # read back with to_here().
    callee_rank = (self.rank + 1) % self.world_size
    rref = self._create_rref()
    callee_name = "worker{}".format(callee_rank)
    ret_rref = rpc.remote(
        callee_name, script_check_rref_confirmed, args=(rref,)
    )
    self.assertEqual(ret_rref.to_here(), True)
|
||||
|
@ -259,6 +259,9 @@ def clear_global_rref():
|
||||
global_rref = None
|
||||
|
||||
|
||||
def check_rref_confirmed(rref):
    # Report whether the given RRef has been confirmed by its owner.
    confirmed = rref.confirmed_by_owner()
    return confirmed
|
||||
|
||||
|
||||
# load_tests from common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
@ -1789,3 +1792,33 @@ class RpcTest(RpcAgentTestFixture):
|
||||
# Sending to self should fail too.
|
||||
with self.assertRaisesRegex(RuntimeError, "RPC backend only supports CPU tensors.*Found tensor on device: cuda:0"):
|
||||
rpc.rpc_sync(worker_name(self.rank), torch.add, args=(t1, t2))
|
||||
|
||||
def _create_rref(self):
    # Remotely run torch.add(zeros(2, 2), 1) on worker (rank + 2) % world_size
    # and hand back the RRef returned by rpc.remote.
    owner_name = "worker{}".format((self.rank + 2) % self.world_size)
    return rpc.remote(owner_name, torch.add, args=(torch.zeros(2, 2), 1))
|
||||
|
||||
@dist_init
def test_user_rrefs_confirmed(self):
    # Pass a freshly created RRef as an argument of a synchronous RPC to the
    # next worker; the Python callee must see it as confirmed.
    callee_rank = (self.rank + 1) % self.world_size
    rref = self._create_rref()
    callee_name = "worker{}".format(callee_rank)
    ret = rpc.rpc_sync(callee_name, check_rref_confirmed, args=(rref,))
    self.assertEqual(ret, True)
|
||||
|
||||
@dist_init
def test_user_rrefs_confirmed_remote(self):
    # Same check as the rpc_sync variant, but issued through rpc.remote and
    # read back with to_here().
    callee_rank = (self.rank + 1) % self.world_size
    rref = self._create_rref()
    callee_name = "worker{}".format(callee_rank)
    ret_rref = rpc.remote(callee_name, check_rref_confirmed, args=(rref,))
    self.assertEqual(ret_rref.to_here(), True)
|
||||
|
Reference in New Issue
Block a user