Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022

This change implements the "FAST" mode distributed autograd backward pass as described in https://github.com/pytorch/pytorch/issues/23110.

At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls `torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in the backward call and all the 'send' functions present in the current autograd context. The "FAST" mode assumes all 'send' functions are part of the autograd computation.
3. Once the dependency computation is done, the distributed autograd engine calls the local autograd engine to execute the autograd graph. Note that the autograd graph on a single node is not necessarily connected because of inter-node communication. As a result, we have special handling to ensure the local autograd engine executes the entire graph starting from the provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async RPC to send the gradients over to the appropriate node and stores a future in the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and enqueued on the local autograd engine. If this is the first time the node is hearing about this autograd context id on the backward pass, the node computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers all leaf nodes and ensures those are passed as 'outputs' to the local autograd engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in `DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine to complete and also waits for all the 'Futures' (stored in step 4) for the respective RPCs to finish.

We have made the following changes to the local autograd engine for this purpose:
1. Expose GraphTask and NodeTask so that the distributed autograd engine can use them.
2. Expose an `execute_with_graph_task` API which allows the distributed engine to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_on_cpu` API, which allows the distributed engine to build a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.

In addition to this, a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If the message returned by Future.wait() has type EXCEPTION, we throw an appropriate exception instead of just returning the message. This is in line with what most Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map from Tensor to respective gradient for the provided context_id on the local node.

ghstack-source-id: 91794926

Test Plan: unit tests.

Differential Revision: D17652615

fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
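To make the flow above concrete, here is a minimal usage sketch of the user-facing pieces this summary refers to: a distributed autograd context, the distributed backward call, and `get_gradients(context_id)`. The exact module paths and call signatures are assumptions based on this summary, not necessarily the final public API.

    # Hypothetical usage sketch; dist_autograd.context(), dist_autograd.backward()
    # and dist_autograd.get_gradients() are assumed names from the summary above.
    import torch
    import torch.distributed.rpc as rpc
    import torch.distributed.autograd as dist_autograd

    with dist_autograd.context() as context_id:
        t = torch.rand(3, 3, requires_grad=True)
        # The remote call records a 'send' function here and a matching 'recv'
        # function on "worker1", linking the two local autograd graphs.
        loss = rpc.rpc_sync("worker1", torch.add, args=(t, t)).sum()
        # FAST-mode backward: dependencies are computed from the roots and from
        # all 'send' functions recorded in this context, then each node's local
        # engine executes its portion of the graph.
        dist_autograd.backward([loss])
        # Gradients are accumulated in the context (not in t.grad); retrieve the
        # Tensor -> gradient map for this node.
        grads = dist_autograd.get_gradients(context_id)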
215 lines · 8.7 KiB · C++
#include <torch/csrc/distributed/rpc/request_callback_impl.h>

#include <c10/util/C++17.h>
#include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
#include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/future_message.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/python_udf_call.h>
#include <torch/csrc/distributed/rpc/python_udf_resp.h>
#include <torch/csrc/distributed/rpc/rref.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

using namespace torch::distributed::autograd;

std::unique_ptr<RpcCommandBase> RequestCallbackImpl::processRpc(
    RpcCommandBase& rpc,
    MessageType messageType) const {
  // TODO: RpcCommandBase should have an abstract execute() method that we can
  // call here instead of having another switch statement here. Even better we
  // could have abstract classes RpcRequest and RpcResp which inherit from
  // RpcCommandBase and RpcRequest declares the abstract method execute() that
  // we can call here. RpcResponse could have an abstract method to convert it
  // to a python object.
  switch (messageType) {
    case MessageType::SCRIPT_CALL: {
      auto& scriptCall = static_cast<ScriptCall&>(rpc);

      // sc is only alive within this block, use reference to avoid copy
      auto& stack = scriptCall.stackRef();
      scriptCall.op()->getOperation()(stack);

      TORCH_INTERNAL_ASSERT(
          stack.size() == 1,
          "Return value of a builtin operator or a "
          "TorchScript function should be a single IValue, got a vector of "
          "size ",
          stack.size());

      return c10::guts::make_unique<ScriptResp>(std::move(stack.front()));
    }
    case MessageType::PYTHON_CALL: {
      auto& pyCall = static_cast<PythonUDFCall&>(rpc);
      std::vector<torch::Tensor> responseTensorTable;
      auto payload = PythonRpcHandler::getInstance().generatePythonUDFResult(
          pyCall.pickledPayload(), pyCall.tensors(), responseTensorTable);
      return c10::guts::make_unique<PythonUDFResp>(
          std::move(payload), std::move(responseTensorTable));
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      auto& src = static_cast<ScriptRemoteCall&>(rpc);
      auto& ctx = RRefContext::getInstance();

      auto ownerRRef = ctx.getOrCreateOwnerRRef<IValue>(src.retRRefId());

      // TODO: make this asynchronous
      // src is only alive within this block, use reference to avoid copy
      auto& stack = src.stackRef();
      src.op()->getOperation()(stack);
      TORCH_INTERNAL_ASSERT(
          stack.size() == 1,
          "Return value of a builtin operator or a "
          "TorchScript function should be a single IValue, got a vector of "
          "size ",
          stack.size());

      ownerRRef->setValue(std::move(stack.front()));
      ctx.addForkOfOwner(src.retRRefId(), src.retForkId());
      return c10::guts::make_unique<RemoteRet>(
          src.retRRefId(), src.retForkId());
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      auto& prc = static_cast<PythonRemoteCall&>(rpc);

      auto rrefId = RRefId::fromIValue(prc.retRRefId());
      auto forkId = ForkId::fromIValue(prc.retForkId());
      auto& ctx = RRefContext::getInstance();

      auto ownerRRef = ctx.getOrCreateOwnerRRef<py::object>(rrefId);
      ownerRRef->setValue(
          PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()));
      ctx.addForkOfOwner(rrefId, forkId);
      return c10::guts::make_unique<RemoteRet>(rrefId, forkId);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);
      auto& ctx = RRefContext::getInstance();
      // TODO: make this asynchronous
      std::shared_ptr<OwnerRRef<IValue>> rref =
          ctx.getOrCreateOwnerRRef<IValue>(srf.rrefId());
      return c10::guts::make_unique<RRefFetchRet>(
          RRefFetchRet({rref->getValue()}));
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
      auto& ctx = RRefContext::getInstance();
      // TODO: make this asynchronous
      std::shared_ptr<OwnerRRef<py::object>> rref =
          ctx.getOrCreateOwnerRRef<py::object>(prf.rrefId());
      SerializedPyObj result =
          PythonRpcHandler::getInstance().serialize(rref->getValue());
      return c10::guts::make_unique<RRefFetchRet>(
          RRefFetchRet(result.toIValues()));
    }
    case MessageType::RREF_USER_DELETE: {
      auto& rud = static_cast<RRefUserDelete&>(rpc);
      auto& ctx = RRefContext::getInstance();
      ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
      return c10::guts::make_unique<RRefAck>();
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      auto& rca = static_cast<RRefChildAccept&>(rpc);
      auto& ctx = RRefContext::getInstance();
      ctx.delPendingChild(rca.forkId());
      return c10::guts::make_unique<RRefAck>();
    }
    case MessageType::RREF_FORK_REQUEST: {
      auto& rfr = static_cast<RRefForkRequest&>(rpc);
      auto& ctx = RRefContext::getInstance();
      ctx.addForkOfOwner(rfr.rrefId(), rfr.forkId());
      return c10::guts::make_unique<RRefAck>();
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
      const auto& autogradMetadata = rpcWithAutograd.autogradMetadata();

      // Attach 'recv' autograd function.
      DistAutogradContext* autogradContext = addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId());

      // Process the original RPC.
      auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
      auto wrappedRpcResponse =
          processRpc(rpcWithAutograd.wrappedRpc(), wrappedMessageType);

      // Wrap the response with autograd, need a new autograd message id for
      // each send/recv pair.
      auto& autogradContainer = DistAutogradContainer::getInstance();
      AutogradMetadata responseAutogradMetadata(
          autogradMetadata.autogradContextId,
          autogradContainer.newAutogradMessageId());

      auto response = c10::guts::make_unique<RpcWithAutograd>(
          rpc::RpcAgent::getDefaultRpcAgent()->getWorkerInfo().id_,
          MessageType::FORWARD_AUTOGRAD_RESP,
          responseAutogradMetadata,
          std::move(wrappedRpcResponse));

      // Attach the 'send' autograd function if needed.
      if (autogradContext != nullptr) {
        addSendRpcBackward(
            *autogradContext, responseAutogradMetadata, response->tensors());
      }
      return std::move(response);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
      const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

      // Retrieve the appropriate autograd context.
      auto& autogradContext =
          DistAutogradContainer::getInstance().retrieveContext(
              autogradMetadata.autogradContextId);

      // Lookup the appropriate 'send' function to enqueue.
      std::shared_ptr<SendRpcBackward> sendFunction =
          autogradContext.retrieveSendFunction(
              autogradMetadata.autogradMessageId);

      // Attach the gradients to the send function.
      sendFunction->setGrads(gradientsCall.getGrads());

      // Now execute the autograd graph using the "distributed engine."
      DistEngine::getInstance().executeSendFunction(
          autogradContext, sendFunction);

      return c10::guts::make_unique<PropagateGradientsResp>();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", messageType, " not supported.");
    }
  }
}

Message RequestCallbackImpl::processMessage(Message& request) const {
  std::unique_ptr<RpcCommandBase> rpc = deserializeRequest(request);
  auto response = processRpc(*rpc, request.type());
  if (response == nullptr) {
    return Message();
  }
  auto responseMessage = std::move(*response).toMessage();
  responseMessage.setId(request.id());
  return responseMessage;
}

} // namespace rpc
} // namespace distributed
} // namespace torch