Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29696

The paths distributed/autograd/context/dist_autograd_context.h and distributed/autograd/context/dist_autograd_container.h were repetitive, so they have been renamed to distributed/autograd/context/context.h and distributed/autograd/context/container.h.

ghstack-source-id: 93850266
Test Plan: waitforbuildbot
Differential Revision: D18467624
fbshipit-source-id: bbf3905396f553006851af296c880c1bd106ec47
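For code that consumed the old headers, only the include path changes (an illustrative before/after sketch, not part of this diff):

    // Before this change:
    // #include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
    // #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>

    // After this change:
    #include <torch/csrc/distributed/autograd/context/context.h>
    #include <torch/csrc/distributed/autograd/context/container.h>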
125 lines · 4.3 KiB · C++
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::FutureMessage;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::RpcCommandBase;
using torch::distributed::rpc::WorkerInfo;

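// Attaches a 'SendRpcBackward' autograd function to the given tensors and
// records it (along with the destination worker id) in the provided
// distributed autograd context.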
void addSendRpcBackward(
    DistAutogradContext& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    const rpc::worker_id_t dst) {
  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext.addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
  // Record the workerID.
  autogradContext.addKnownWorkerId(dst);
}

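// Retrieves (or creates) the autograd context for the given metadata and, if
// any received tensor requires grad, attaches a 'RecvRpcBackward' autograd
// function to the tensors and records it in the context.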
DistAutogradContext* addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  DistAutogradContext& autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty()) {
    TORCH_INTERNAL_ASSERT(
        torch::autograd::compute_requires_grad(tensors),
        "Received tensors do not require grad, addRecvRpcBackward should not be called");

    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId);
    for (auto& tensor : tensors) {
      torch::autograd::set_history(tensor, grad_fn);
    }

    // Now update the autograd context with the necessary information.
    autogradContext.addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return &autogradContext;
}

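// Wraps the given RPC message with autograd metadata when a valid context
// exists and grad recording is needed; otherwise returns the original
// message unchanged.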
Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grad, send the
  // original RPC message. Otherwise, attach the autograd metadata and grad
  // functions and send an RpcWithAutograd message instead.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto& autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext.contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = c10::guts::make_unique<RpcWithAutograd>(
      RpcAgent::getDefaultRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg));

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors(), dstId);
  }

  return std::move(*rpcWithAutograd).toMessage();
}

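// Convenience wrapper: attaches autograd information (if needed) via
// getMessageWithAutograd and sends the resulting FORWARD_AUTOGRAD_REQ
// message through the given agent.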
std::shared_ptr<FutureMessage> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording);

  return agent.send(dst, std::move(msg));
}

} // namespace autograd
} // namespace distributed
} // namespace torch
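// Illustrative usage sketch (not part of this file; the caller-side setup is
// assumed): a worker with an active distributed autograd context that wants
// grad recording over RPC would invoke the helper roughly as follows:
//
//   auto futureResponse = sendMessageWithAutograd(
//       *RpcAgent::getDefaultRpcAgent(),
//       dstWorkerInfo, // WorkerInfo for the destination worker (hypothetical)
//       std::move(requestMessage),
//       /*forceGradRecording=*/false);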