[1.5 Release][Dist Autograd][Better Engineering] Notify Workers on Failure during Distributed Autograd (#34638)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638

Fixes: https://github.com/pytorch/pytorch/issues/27643

This PR notifies workers in the event of a failure during distributed autograd. Errors encountered during the backward pass are gracefully propagated to all participating nodes, and state is set in the local autograd engines accordingly.

(Note: this ignores all push blocking failures!)

Test Plan: Added 2 new tests that verify errors thrown on an intermediate node during distributed autograd are handled correctly. Ensured that all existing distributed autograd tests pass.
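For illustration, here is a minimal sketch of the scenario the new tests exercise (assuming rpc.init_rpc has already been run on a group of workers named "worker0", "worker1", ...; SimulateBackwardError is a stand-in for the helper of the same name in the distributed autograd test suite):

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    class SimulateBackwardError(torch.autograd.Function):
        # Identity in forward; raises once gradients reach it in backward.
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad):
            raise RuntimeError("Simulated error during the backward pass")

    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        t3 = SimulateBackwardError.apply(t1 + t2)
        # Chain RPCs across other workers so the backward pass spans
        # several nodes before the simulated error is hit.
        val = rpc.rpc_sync("worker1", torch.mul, args=(t2, t3))
        val = rpc.rpc_sync("worker2", torch.matmul, args=(val, t2))
        try:
            dist_autograd.backward(context_id, [val.sum()])
        except RuntimeError:
            # With this change the failure is propagated to every
            # participating worker, so no remote autograd engine is left
            # waiting for gradients that will never arrive.
            pass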

Differential Revision: D20164420

fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
Omkar Salpekar
2020-03-18 18:52:08 -07:00
committed by Facebook GitHub Bot
parent a73dfcf8cf
commit 5f67c923f1
17 changed files with 320 additions and 30 deletions

View File

@ -532,6 +532,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp

View File

@ -61,6 +61,8 @@ libtorch_sources = [
"torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp",
"torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.cpp",
"torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp",
"torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp",
"torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp",
"torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp",
"torch/csrc/distributed/rpc/message.cpp",
"torch/csrc/distributed/rpc/python_call.cpp",

View File

@ -820,7 +820,7 @@ void Engine::mark_graph_task_completed(std::shared_ptr<GraphTask>& graph_task) {
graph_task->future_result_->markCompleted(
std::move(graph_task->captured_vars_));
} catch (std::exception& e) {
graph_task->future_result_->setError(e.what());
graph_task->future_result_->setErrorIfNeeded(e.what());
}
}

View File

@ -204,6 +204,16 @@ ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
return autograd_context_.at(context_id);
}
ContextPtr DistAutogradContainer::retrieveContextIfPresent(int64_t context_id) {
std::lock_guard<std::mutex> guard(autograd_context_lock_);
auto it = autograd_context_.find(context_id);
if (it != autograd_context_.end()) {
return it->second;
} else {
return nullptr;
}
}
int64_t DistAutogradContainer::getMaxId() {
return max_id_;
}

View File

@ -52,6 +52,10 @@ class TORCH_API DistAutogradContainer {
// Retrieve the autograd context for a given context_id.
ContextPtr retrieveContext(int64_t context_id);
// Retrieve the autograd context for a given context_id if it exists,
// otherwise return nullptr.
ContextPtr retrieveContextIfPresent(int64_t context_id);
// Retrieves the currently active autograd context for the current thread.
ContextPtr currentContext();

View File

@ -3,6 +3,7 @@
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/distributed/autograd/context/context.h>
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
namespace torch {
namespace distributed {
@ -103,6 +104,14 @@ std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
return graphTask_;
}
std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
retrieveGraphTaskIfExists() {
// Similar to retrieveGraphTask() but does not throw an exception if
// GraphTask doesn't exist.
std::lock_guard<std::mutex> guard(lock_);
return graphTask_;
}
void DistAutogradContext::setGraphTask(
std::shared_ptr<torch::autograd::GraphTask> graphTask) {
std::lock_guard<std::mutex> guard(lock_);
@ -125,17 +134,7 @@ void DistAutogradContext::addOutstandingRpc(
const c10::optional<utils::FutureError>& futErr) {
if (futErr) {
// If we have an error, let the local autograd engine know about it.
std::runtime_error err((*futErr).what());
std::unique_lock<std::mutex> lock(lock_);
if (graphTask_) {
graphTask_->set_exception_without_signal(nullptr);
lock.unlock();
graphTask_->future_result_->setErrorIfNeeded(err.what());
} else {
LOG(WARNING)
<< "Ignoring error since GraphTask is no longer valid: "
<< err.what();
}
setGraphTaskException((*futErr).what());
}
});
std::lock_guard<std::mutex> guard(lock_);
@ -193,6 +192,33 @@ std::shared_ptr<rpc::FutureMessage> DistAutogradContext::
return state->future;
}
bool DistAutogradContext::setGraphTaskException(const std::string& errorMsg) {
std::unique_lock<std::mutex> lock(lock_);
if (graphTask_) {
if (graphTask_->has_error_) {
lock.unlock();
return true;
}
graphTask_->set_exception_without_signal(nullptr);
lock.unlock();
graphTask_->future_result_->setErrorIfNeeded(errorMsg);
return false;
} else {
LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
<< errorMsg;
return true;
}
}
void DistAutogradContext::propagateAutogradError(const std::string& errorMsg) {
auto neighborNodes = getKnownWorkerIds();
auto agent = rpc::RpcAgent::getCurrentRpcAgent();
for (const auto& node : neighborNodes) {
DistAutogradFailureReq msg(contextId_, errorMsg);
agent->send(agent->getWorkerInfo(node), std::move(msg).toMessage());
}
}
std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
int64_t autograd_message_id) {
std::lock_guard<std::mutex> guard(lock_);

View File

@ -67,6 +67,17 @@ class TORCH_API DistAutogradContext {
// These are the different workers that this context has sent RPCs to.
std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;
// Propagates the Autograd Failure message to all known nodes.
void propagateAutogradError(const std::string& errorMsg);
// Sets an error on the GraphTask if it has not already been set with one.
// Returns true if no further processing of the autograd failure is
// necessary (for example, there is no need to propagate the failure
// message further); this is the case when the graphTask was already
// marked with an exception or is no longer valid. Returns false if the
// exception was newly set and the failure still needs to be propagated.
bool setGraphTaskException(const std::string& errorMsg);
private:
friend class BackwardPassCleanupGuard;
friend class DistEngine;
@ -82,6 +93,12 @@ class TORCH_API DistAutogradContext {
// Retrieve the GraphTask.
std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();
// An idempotent variant of retrieveGraphTask(). Since error notifications
// during distributed autograd are propagated recursively, this function may
// be called multiple times for the same context and must not crash if the
// GraphTask has already been set with an error by a previous RPC.
std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTaskIfExists();
// Set the appropriate graph task for the backward pass. Can be called only
// once.
void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);

View File

@ -347,10 +347,21 @@ void DistEngine::execute(
BackwardPassCleanupGuard guard(autogradContext);
auto execFuture =
runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges);
// This callback propagates the dist autograd error to other nodes if it
// encounters a failure.
execFuture->addCallback(
[autogradContext](
const rpc::Message& /* unused */,
const c10::optional<torch::utils::FutureError>& error) {
if (error) {
autogradContext->propagateAutogradError(error->what());
}
});
// This needs to be blocking and as a result we wait for the future to
// complete.
runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
->wait();
execFuture->wait();
// Wait for all of the outstanding rpcs to complete.
autogradContext->clearAndWaitForOutstandingRpcsAsync()->wait();

View File

@ -0,0 +1,66 @@
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace autograd {
DistAutogradFailureReq::DistAutogradFailureReq(
int64_t context_id,
std::string errorMsg)
: context_id_(context_id), errorMsg_(std::move(errorMsg)) {}
rpc::Message DistAutogradFailureReq::toMessage() && {
std::vector<at::IValue> ivalues;
// add context_id and errorMsg
ivalues.emplace_back(context_id_);
ivalues.emplace_back(errorMsg_);
// Now pickle using JIT pickler.
std::vector<torch::Tensor> tensorTable;
std::vector<char> payload =
jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
return rpc::Message(
std::move(payload),
std::move(tensorTable),
rpc::MessageType::DIST_AUTOGRAD_FAILURE_REQ);
}
std::unique_ptr<DistAutogradFailureReq> DistAutogradFailureReq::fromMessage(
const rpc::Message& message) {
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue tuple = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
&message.tensors());
std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
TORCH_INTERNAL_ASSERT(tupleElements.size() == 2);
// recover errorMsg
std::string errorMsg = tupleElements.back().toString()->string();
tupleElements.pop_back();
// recover context_id
int64_t context_id = tupleElements.back().toInt();
tupleElements.pop_back();
return std::make_unique<DistAutogradFailureReq>(context_id, errorMsg);
}
int64_t DistAutogradFailureReq::getContextId() {
return context_id_;
}
std::string DistAutogradFailureReq::getErrorMsg() {
return errorMsg_;
}
} // namespace autograd
} // namespace distributed
} // namespace torch
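For intuition, the request payload is simply the pair (context_id, errorMsg) pickled as a two-element tuple, with a tensor table alongside it (empty here, since neither field references a tensor). A rough Python analogue of the round trip, using the standard pickle module as a stand-in for the JIT pickler:

    import pickle

    def to_payload(context_id: int, error_msg: str) -> bytes:
        # The real code pickles an IValue tuple via torch::jit::pickle.
        return pickle.dumps((context_id, error_msg))

    def from_payload(payload: bytes):
        context_id, error_msg = pickle.loads(payload)
        return context_id, error_msg

    assert from_payload(to_payload(7, "simulated error")) == (7, "simulated error")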

View File

@ -0,0 +1,30 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <vector>
namespace torch {
namespace distributed {
namespace autograd {
// Used to notify other workers when there is an autograd error.
class TORCH_API DistAutogradFailureReq : public rpc::RpcCommandBase {
public:
DistAutogradFailureReq(int64_t context_id, std::string errorMsg);
// Serialization and deserialization methods.
rpc::Message toMessage() && override;
static std::unique_ptr<DistAutogradFailureReq> fromMessage(
const rpc::Message& message);
int64_t getContextId();
std::string getErrorMsg();
private:
int64_t context_id_;
std::string errorMsg_;
};
} // namespace autograd
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,23 @@
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
namespace torch {
namespace distributed {
namespace autograd {
rpc::Message DistAutogradFailureResp::toMessage() && {
std::vector<torch::Tensor> tensors;
std::vector<char> payload;
return rpc::Message(
std::move(payload),
std::move(tensors),
rpc::MessageType::DIST_AUTOGRAD_FAILURE_RESP);
}
std::unique_ptr<DistAutogradFailureResp> DistAutogradFailureResp::fromMessage(
const rpc::Message& message /* unused */) {
return std::unique_ptr<DistAutogradFailureResp>();
}
} // namespace autograd
} // namespace distributed
} // namespace torch

View File

@ -0,0 +1,24 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <vector>
namespace torch {
namespace distributed {
namespace autograd {
// Empty response for DistAutogradFailureReq. Sent to acknowledge receipt of
// a DistAutogradFailureReq.
class TORCH_API DistAutogradFailureResp : public rpc::RpcCommandBase {
public:
DistAutogradFailureResp() = default;
// Serialization and deserialization methods.
rpc::Message toMessage() && override;
static std::unique_ptr<DistAutogradFailureResp> fromMessage(
const rpc::Message& message);
};
} // namespace autograd
} // namespace distributed
} // namespace torch

View File

@ -86,7 +86,9 @@ bool Message::isRequest() const {
MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
MessageType::FORWARD_AUTOGRAD_REQ == type_ ||
// Cleanup Autograd context request
MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_;
MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ ||
// Autograd Backward Error Notification request
MessageType::DIST_AUTOGRAD_FAILURE_REQ == type_;
}
bool Message::isResponse() const {
@ -101,7 +103,9 @@ bool Message::isResponse() const {
MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
MessageType::FORWARD_AUTOGRAD_RESP == type_ ||
// Cleanup autograd context response
MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_;
MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ ||
// Autograd Backward Error Notification response
MessageType::DIST_AUTOGRAD_FAILURE_RESP == type_;
}
int64_t Message::id() const {

View File

@ -44,6 +44,9 @@ enum MessageType {
CLEANUP_AUTOGRAD_CONTEXT_REQ = 19,
CLEANUP_AUTOGRAD_CONTEXT_RESP = 20,
DIST_AUTOGRAD_FAILURE_REQ = 21,
DIST_AUTOGRAD_FAILURE_RESP = 22,
// Other internal message types
EXCEPTION = 55,
UNKNOWN = 60

View File

@ -6,6 +6,8 @@
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
@ -396,11 +398,22 @@ void RequestCallbackImpl::processRpc(
case MessageType::BACKWARD_AUTOGRAD_REQ: {
auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
std::shared_ptr<DistAutogradContext> autogradContext;
// Retrieve the appropriate autograd context.
auto autogradContext =
DistAutogradContainer::getInstance().retrieveContext(
// In rare cases, a BACKWARD_AUTOGRAD_REQ may arrive after the
// context has been cleaned up due to an error during the backward pass.
// In such situations, we can ignore this message since no further
// gradient computations should take place on this context.
autogradContext =
DistAutogradContainer::getInstance().retrieveContextIfPresent(
autogradMetadata.autogradContextId);
if (!autogradContext) {
LOG(INFO) << "Ignoring Backward Autograd Request. Context "
<< autogradMetadata.autogradContextId
<< " already cleaned up due to autograd error";
markComplete(std::move(PropagateGradientsResp()).toMessage());
return;
}
// Lookup the appropriate 'send' function to enqueue.
std::shared_ptr<SendRpcBackward> sendFunction =
@ -441,6 +454,33 @@ void RequestCallbackImpl::processRpc(
markComplete(std::move(CleanupAutogradContextResp()).toMessage());
return;
}
case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
auto& backwardFailureReq = static_cast<DistAutogradFailureReq&>(rpc);
auto errorContextId = backwardFailureReq.getContextId();
auto errorMsg = backwardFailureReq.getErrorMsg();
// Mark the given context's graphTask with an error if the context has
// not been cleaned up yet. If the error was newly set, propagate the
// error message to all known neighbor workers.
std::shared_ptr<DistAutogradContext> autogradContext;
autogradContext =
DistAutogradContainer::getInstance().retrieveContextIfPresent(
errorContextId);
if (!autogradContext) {
LOG(INFO) << "Ignoring Dist Autograd Failure Request. Context "
<< errorContextId
<< " already cleaned up due to autograd error";
markComplete(std::move(DistAutogradFailureResp()).toMessage());
return;
}
// Mark the graphTask with an exception.
bool graphTaskHadError = autogradContext->setGraphTaskException(errorMsg);
// Propagate the dist autograd failure message to known workers.
if (!graphTaskHadError) {
autogradContext->propagateAutogradError(errorMsg);
}
markComplete(std::move(DistAutogradFailureResp()).toMessage());
return;
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Request type ", messageType, " not supported.");

View File

@ -2,6 +2,8 @@
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
@ -58,6 +60,9 @@ std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
return autograd::CleanupAutogradContextReq::fromMessage(request);
}
case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
return autograd::DistAutogradFailureReq::fromMessage(request);
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Request type ", request.type(), " not supported.");
@ -109,6 +114,9 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
return autograd::CleanupAutogradContextResp::fromMessage(response);
}
case MessageType::DIST_AUTOGRAD_FAILURE_RESP: {
return autograd::DistAutogradFailureResp::fromMessage(response);
}
default: {
TORCH_INTERNAL_ASSERT(
false, "Response type ", response.type(), " not supported.");

View File

@ -1179,6 +1179,37 @@ class DistAutogradTest(RpcAgentTestFixture):
# Run backwards, and validate we receive an error.
dist_autograd.backward(context_id, [val.sum()])
@dist_init
def test_backward_intermediate_autograd_engine_error(self):
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some ops before error simulation.
tmp = (t1 + t2) * (t1 + t2)
t3 = SimulateBackwardError.apply(tmp)
# Run multiple round trips across different nodes and verify the
# original node receives an error thrown at some intermediate node
# in the chain.
val = rpc.rpc_sync(
"worker{}".format(self._next_rank()), torch.add, args=(t2, t1)
)
val = rpc.rpc_sync(
"worker{}".format(self._next_rank()), torch.mul, args=(val, t3)
)
val = rpc.rpc_sync(
"worker{}".format(self._next_rank()), torch.matmul, args=(val, t2)
)
val = rpc.rpc_sync(
"worker{}".format(self._next_rank()), torch.div, args=(val, t2)
)
with self.assertRaises(RuntimeError):
# Run backwards, and validate we receive an error.
dist_autograd.backward(context_id, [val.sum()])
self.assertTrue(_all_contexts_cleaned_up())
@dist_init(clean_shutdown=False)
@unittest.skipIf(
IS_MACOS,
@ -1570,17 +1601,6 @@ class DistAutogradTest(RpcAgentTestFixture):
):
dist_autograd.backward(context_id, [t1.sum()])
# HACK: Killing workers since otherwise the autograd engine gets stuck on
# other nodes. The proper fix would be addressing:
# https://github.com/pytorch/pytorch/issues/27643, which would inform
# other nodes about the failure.
# The autograd engine gets stuck on other nodes since they're waiting to
# receive gradients from the node that received an error (and as a
# result it didn't execute the rest of the graph).
dist.barrier()
rpc.shutdown(graceful=False)
sys.exit(0)
@classmethod
def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
embedding = embedding_rref.local_value()