pytorch/torch/csrc/distributed/autograd/utils.cpp
Pritam Damania 77bb41c965 Rename dist_autograd_context and dist_autograd_container. (#29696)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29696

The paths distributed/autograd/context/dist_autograd_context.h and
distributed/autograd/context/dist_autograd_container.h were repetitive.

These are therefore renamed to distributed/autograd/context/context.h and
distributed/autograd/context/container.h.
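
For example, an include of the container header changes as follows:

  // Before this change:
  #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
  // After this change:
  #include <torch/csrc/distributed/autograd/context/container.h>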
ghstack-source-id: 93850266

Test Plan: waitforbuildbot

Differential Revision: D18467624

fbshipit-source-id: bbf3905396f553006851af296c880c1bd106ec47
2019-11-14 14:49:34 -08:00

#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::FutureMessage;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::RpcCommandBase;
using torch::distributed::rpc::WorkerInfo;
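
// Attaches a SendRpcBackward autograd function to the given tensors, records
// it in the provided distributed autograd context under the message id from
// 'autogradMetadata', and tracks 'dst' as a worker known to this context.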
void addSendRpcBackward(
    DistAutogradContext& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    const rpc::worker_id_t dst) {
  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext.addSendFunction(grad_fn, autogradMetadata.autogradMessageId);

  // Record the workerID.
  autogradContext.addKnownWorkerId(dst);
}
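
// Retrieves (or creates) the distributed autograd context for the given
// context id. If any received tensor requires grad, attaches a
// RecvRpcBackward autograd function to the tensors and records it in the
// context. Returns a pointer to the context.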
DistAutogradContext* addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  DistAutogradContext& autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty()) {
    TORCH_INTERNAL_ASSERT(
        torch::autograd::compute_requires_grad(tensors),
        "Received tensors do not require grad, addRecvRpcBackward should not be called");

    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId);
    for (auto& tensor : tensors) {
      torch::autograd::set_history(tensor, grad_fn);
    }

    // Now update the autograd context with the necessary information.
    autogradContext.addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return &autogradContext;
}
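
// Decides whether 'wrappedRpcMsg' needs autograd handling: if there is a
// valid distributed autograd context and either some tensor requires grad or
// 'forceGradRecording' is set, wraps the message (as 'msgType') with autograd
// metadata and records the 'send' side; otherwise returns the message as is.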
Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grad, send the
  // original rpc message. Otherwise, attach grad info and grad functions and
  // send an rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto& autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext.contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = c10::guts::make_unique<RpcWithAutograd>(
      RpcAgent::getDefaultRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg));

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors(), dstId);
  }

  return std::move(*rpcWithAutograd).toMessage();
}
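
// Convenience wrapper: runs 'wrappedRpcMsg' through getMessageWithAutograd
// with the FORWARD_AUTOGRAD_REQ message type and sends the result via the
// given RPC agent.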
std::shared_ptr<FutureMessage> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording);
  return agent.send(dst, std::move(msg));
}

} // namespace autograd
} // namespace distributed
} // namespace torch
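
A rough usage sketch (assumed caller code, not part of this file or the PR): the RPC layer can route an outgoing message through sendMessageWithAutograd so that autograd information is recorded only when a distributed autograd context is active. The names agent, dstWorker, and msg are hypothetical and assumed to be constructed elsewhere via the rpc APIs.

  // Hypothetical caller, illustrative only.
  using torch::distributed::autograd::sendMessageWithAutograd;
  using torch::distributed::rpc::FutureMessage;

  std::shared_ptr<FutureMessage> fut = sendMessageWithAutograd(
      agent,                          // RpcAgent performing the actual send
      dstWorker,                      // WorkerInfo of the destination
      std::move(msg),                 // message to (possibly) wrap
      /*forceGradRecording=*/false);  // wrap only if some tensor requires grad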