Files
pytorch/torch/csrc/distributed/rpc/request_callback_impl.cpp
Michael Suo dbe850af5b [jit] do the code reorg (#33851)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33851

Rationale and context described in #33828.

Script to reproduce the move:
https://gist.github.com/suo/16cbefaaeb67ca5a7c6caffd49b7f6e9
ghstack-source-id: 99079645

Test Plan: Make sure CI passes

Reviewed By: jamesr66a

Differential Revision: D20133869

fbshipit-source-id: 390e9241a9c85366d9005c492ac31f10aa96488e
2020-02-27 13:02:51 -08:00

369 lines
15 KiB
C++

#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace torch {
namespace distributed {
namespace rpc {
using namespace torch::distributed::autograd;
std::shared_ptr<FutureMessage> RequestCallbackImpl::processRpc(
RpcCommandBase& rpc,
MessageType messageType,
int64_t messageId) const {
auto wrap = [messageId](Message m) {
m.setId(messageId);
return std::make_shared<FutureMessage>(std::move(m));
};
// TODO: RpcCommandBase should have an abstract execute() method that we can
// call here instead of having another switch statement here. Even better we
// could have abstract classes RpcRequest and RpcResp which inherit from
// RpcCommandBase and RpcRequest declares the abstract method execute() that
// we can call here. RpcResponse could have an abstract method to convert it
// to a python object.
switch (messageType) {
case MessageType::SCRIPT_CALL: {
auto& scriptCall = static_cast<ScriptCall&>(rpc);
// scriptCall is only alive within this block, use reference to avoid copy
auto& stack = scriptCall.stackRef();
if (scriptCall.hasOp()) {
scriptCall.op()->getOperation()(stack);
} else {
PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptCall.qualifiedName())
.run(stack);
}
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
"Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of "
"size ",
stack.size());
return wrap(std::move(ScriptResp(std::move(stack.front()))).toMessage());
}
case MessageType::PYTHON_CALL: {
auto& pyCall = static_cast<PythonCall&>(rpc);
std::vector<torch::Tensor> responseTensorTable;
auto payload = PythonRpcHandler::getInstance().generatePythonUDFResult(
pyCall.pickledPayload(), pyCall.tensors(), responseTensorTable);
return wrap(
std::move(
PythonResp(std::move(payload), std::move(responseTensorTable)))
.toMessage());
}
case MessageType::SCRIPT_REMOTE_CALL: {
auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
auto rrefId = scriptRemoteCall.retRRefId();
auto forkId = scriptRemoteCall.retForkId();
auto& ctx = RRefContext::getInstance();
TypePtr returnType;
if (scriptRemoteCall.hasOp()) {
returnType = scriptRemoteCall.op()->schema().returns()[0].type();
} else {
returnType = PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptRemoteCall.qualifiedName())
.getSchema()
.returns()
.at(0)
.type();
}
auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, returnType);
// TODO: make this asynchronous
// scriptRemoteCall is only alive within this block, use reference to
// avoid copy
auto& stack = scriptRemoteCall.stackRef();
if (scriptRemoteCall.hasOp()) {
scriptRemoteCall.op()->getOperation()(stack);
} else {
PythonRpcHandler::getInstance()
.jitCompilationUnit()
->get_function(scriptRemoteCall.qualifiedName())
.run(stack);
}
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
"Return value of a builtin operator or a "
"TorchScript function should be a single IValue, got a vector of "
"size ",
stack.size());
ownerRRef->setValue(std::move(stack.front()));
if (rrefId != forkId) {
// Caller is a user and callee is the owner, add fork
//
// NB: rrefId == forkId is true if and only if calling remote to self.
// In that case both the caller and the callee will access the
// OwnerRRef. Hence, on the callee side (here), it should not call
// addForkOfOwner as it is not a fork. To allow callee to distinguish
// when this request is sent to self, the caller will set forkId using
// rrefId (OwnerRRef does not have a forkId anyway).
ctx.addForkOfOwner(rrefId, forkId);
}
return wrap(RemoteRet(rrefId, forkId).toMessage());
}
case MessageType::PYTHON_REMOTE_CALL: {
auto& prc = static_cast<PythonRemoteCall&>(rpc);
auto rrefId = RRefId::fromIValue(prc.retRRefId());
auto forkId = ForkId::fromIValue(prc.retForkId());
auto& ctx = RRefContext::getInstance();
auto ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, PyObjectType::get());
IValue py_ivalue = jit::toIValue(
PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()),
PyObjectType::get());
ownerRRef->setValue(std::move(py_ivalue));
if (rrefId != forkId) {
// Caller is a user and callee is the owner, add fork
//
// NB: rrefId == forkId is true if and only if calling remote to self.
// In that case both the caller and the callee will access the
// OwnerRRef. Hence, on the callee side (here), it should not call
// addForkOfOwner as it is not a fork. To allow callee to distinguish
// when this request is sent to self, the caller will set forkId using
// rrefId (OwnerRRef does not have a forkId anyway).
ctx.addForkOfOwner(rrefId, forkId);
}
return wrap(RemoteRet(rrefId, forkId).toMessage());
}
case MessageType::SCRIPT_RREF_FETCH_CALL: {
auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
c10::intrusive_ptr<OwnerRRef> rref =
ctx.getOwnerRRef(srf.rrefId());
if (rref->hasValue()) { // optional fast-path
return wrap(ScriptRRefFetchRet({rref->getValue()}).toMessage());
}
auto whenValueSet = rref->getFuture();
auto responseFuture = std::make_shared<FutureMessage>();
// Our response is satisfied when the rpcs come back.
whenValueSet->addCallback(
[responseFuture, messageId, rref](
const rpc::Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (!error) {
Message m = ScriptRRefFetchRet({rref->getValue()}).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} else {
responseFuture->setError(error->what());
}
});
return responseFuture;
}
case MessageType::PYTHON_RREF_FETCH_CALL: {
auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
auto& ctx = RRefContext::getInstance();
c10::intrusive_ptr<OwnerRRef> rref =
ctx.getOwnerRRef(prf.rrefId());
if (rref->hasValue()) { // optional fast-path
auto value = rref->getValue();
py::object pyValue;
{
pybind11::gil_scoped_acquire ag;
pyValue = torch::jit::toPyObject(std::move(value));
}
SerializedPyObj result =
PythonRpcHandler::getInstance().serialize(pyValue);
return wrap(PythonRRefFetchRet(result.toIValues()).toMessage());
}
auto whenValueSet = rref->getFuture();
auto responseFuture = std::make_shared<FutureMessage>();
// Our response is satisfied when the rpcs come back.
whenValueSet->addCallback(
[responseFuture, messageId, rref](
const rpc::Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (!error) {
auto value = rref->getValue();
py::object pyValue;
{
pybind11::gil_scoped_acquire ag;
pyValue = torch::jit::toPyObject(std::move(value));
}
SerializedPyObj result =
PythonRpcHandler::getInstance().serialize(pyValue);
Message m = PythonRRefFetchRet(result.toIValues()).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} else {
responseFuture->setError(error->what());
}
});
return responseFuture;
}
case MessageType::RREF_USER_DELETE: {
auto& rud = static_cast<RRefUserDelete&>(rpc);
auto& ctx = RRefContext::getInstance();
auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
if (deletedRRef && deletedRRef->isPyObj()) {
pybind11::gil_scoped_acquire ag;
deletedRRef.reset();
}
return wrap(std::move(RRefAck()).toMessage());
}
case MessageType::RREF_CHILD_ACCEPT: {
auto& rca = static_cast<RRefChildAccept&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx.delPendingChild(rca.forkId());
return wrap(std::move(RRefAck()).toMessage());
}
case MessageType::RREF_FORK_REQUEST: {
auto& rfr = static_cast<RRefForkRequest&>(rpc);
auto& ctx = RRefContext::getInstance();
ctx.addForkOfOwner(rfr.rrefId(), rfr.forkId());
return wrap(RRefAck().toMessage());
}
case MessageType::FORWARD_AUTOGRAD_REQ: {
auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
// Attach 'recv' autograd function.
auto autogradContext = addRecvRpcBackward(
rpcWithAutograd.autogradMetadata(),
rpcWithAutograd.tensors(),
rpcWithAutograd.fromWorkerId());
// For this recv thread on server side, before processRpc(),
// set current_context_id_ to be context_id passed from client.
// In this way, if there is nested rpc call in python rpc call, original
// context_id from client can be passed in the chain calls.
auto& autogradContainer = DistAutogradContainer::getInstance();
TORCH_INTERNAL_ASSERT(
autogradContext != nullptr,
"autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
"or create valid autogradContext in addRecvRpcBackward.");
autogradContainer.setCurrentContextId(autogradContext->contextId());
// Process the original RPC.
auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
// Kick off processing for the nested future and get a Future<T> to the
// result.
auto wrappedRpcResponseFuture = processRpc(
rpcWithAutograd.wrappedRpc(), wrappedMessageType, messageId);
// Make an overall future for the wrapped response.
auto responseFuture = std::make_shared<rpc::FutureMessage>();
auto fromWorkerId = rpcWithAutograd.fromWorkerId();
// The original future needs to be marked as completed when the wrapped
// one completes, with the autograd context information wrapped.
wrappedRpcResponseFuture->addCallback(
[responseFuture, messageId, fromWorkerId, wrappedRpcResponseFuture](
const Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (error) {
// Propagate error to responseFuture if we had one.
responseFuture->setError(error->what());
} else {
auto msg = getMessageWithAutograd(
fromWorkerId,
std::move(*wrappedRpcResponseFuture).moveValue(),
MessageType::FORWARD_AUTOGRAD_RESP);
msg.setId(messageId);
responseFuture->markCompleted(std::move(msg));
}
});
return responseFuture;
}
case MessageType::BACKWARD_AUTOGRAD_REQ: {
auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
// Retrieve the appropriate autograd context.
auto autogradContext =
DistAutogradContainer::getInstance().retrieveContext(
autogradMetadata.autogradContextId);
// Lookup the appropriate 'send' function to enqueue.
std::shared_ptr<SendRpcBackward> sendFunction =
autogradContext->retrieveSendFunction(
autogradMetadata.autogradMessageId);
// Attach the gradients to the send function.
sendFunction->setGrads(gradientsCall.getGrads());
auto responseFuture = std::make_shared<rpc::FutureMessage>();
// Now execute the autograd graph using the "distributed engine."
auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
autogradContext, sendFunction, gradientsCall.retainGraph());
// Our response is satisfied when the rpcs come back.
execFuture->addCallback(
[responseFuture, messageId](
const Message& /* unused */,
const c10::optional<utils::FutureError>& error) {
if (!error) {
Message m = std::move(PropagateGradientsResp()).toMessage();
m.setId(messageId);
responseFuture->markCompleted(std::move(m));
} else {
responseFuture->setError(error->what());
}
});
return responseFuture;
};
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
auto cleanupContextId = cleanupContextReq.getContextId();
// release the context if it still exists on this thread. We need to
// check if it exists since it may have been deleted by an in-flight
// RPC. This can create nested RPCs if there are other nodes that get
// notified to clean up their context.
DistAutogradContainer::getInstance().releaseContextIfPresent(
cleanupContextId);
return wrap(std::move(CleanupAutogradContextResp()).toMessage());
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Request type ", messageType, " not supported.");
}
}
}
std::shared_ptr<FutureMessage> RequestCallbackImpl::processMessage(
Message& request) const {
std::unique_ptr<RpcCommandBase> rpc = deserializeRequest(request);
return processRpc(*rpc, request.type(), request.id());
}
} // namespace rpc
} // namespace distributed
} // namespace torch