From 5f67c923f1555f10d68a3deee2f6086eeaa78b60 Mon Sep 17 00:00:00 2001
From: Omkar Salpekar
Date: Wed, 18 Mar 2020 18:52:08 -0700
Subject: [PATCH] [1.5 Release][Dist Autograd][Better Engineering] Notify
 Workers on Failure during Distributed Autograd (#34638)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638

Fixes: https://github.com/pytorch/pytorch/issues/27643

This PR handles notifying workers in the event of a failure during
distributed autograd. It gracefully propagates errors across all nodes
in the backward pass and sets state in the local autograd engines
accordingly.

(Note: this ignores all push blocking failures!)

Test Plan: Added 2 new tests that verify error handling when an error
is thrown on an intermediate node during distributed autograd. Ensured
that all existing distributed autograd tests pass.

Differential Revision: D20164420

fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
---
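
For illustration, a minimal sketch of the end-to-end behavior this change
enables (hypothetical two-worker setup; the worker names and the
run_backward_with_remote_error helper are illustrative, while
SimulateBackwardError is the existing helper in dist_autograd_test.py).
An error raised while executing the backward pass for a distributed
autograd context now surfaces as a RuntimeError from
dist_autograd.backward() on the caller and is delivered to every other
worker known to the context, instead of leaving those workers blocked
waiting for gradients:

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    # Runs on worker0. Assumes rpc.init_rpc(...) has already been called on
    # every worker, and that SimulateBackwardError is an autograd Function
    # whose backward() raises (see the test changes at the end of this patch).
    def run_backward_with_remote_error():
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            # Insert a node into the autograd graph that fails during backward.
            t3 = SimulateBackwardError.apply(t1 + t2)
            # Route the forward pass through worker1 so that worker1
            # participates in the backward pass for this context.
            val = rpc.rpc_sync("worker1", torch.mul, args=(t3, t2))
            try:
                dist_autograd.backward(context_id, [val.sum()])
            except RuntimeError as e:
                # The failure was propagated to all workers known to this
                # context via DIST_AUTOGRAD_FAILURE_REQ, so none of them hang.
                print("backward failed:", e)
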
"torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp", + "torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp", "torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp", "torch/csrc/distributed/rpc/message.cpp", "torch/csrc/distributed/rpc/python_call.cpp", diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index db23e498ea5a..28a2ab53a492 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -820,7 +820,7 @@ void Engine::mark_graph_task_completed(std::shared_ptr& graph_task) { graph_task->future_result_->markCompleted( std::move(graph_task->captured_vars_)); } catch (std::exception& e) { - graph_task->future_result_->setError(e.what()); + graph_task->future_result_->setErrorIfNeeded(e.what()); } } diff --git a/torch/csrc/distributed/autograd/context/container.cpp b/torch/csrc/distributed/autograd/context/container.cpp index 681b9a8d22d0..3ee6a7faa8c6 100644 --- a/torch/csrc/distributed/autograd/context/container.cpp +++ b/torch/csrc/distributed/autograd/context/container.cpp @@ -204,6 +204,16 @@ ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) { return autograd_context_.at(context_id); } +ContextPtr DistAutogradContainer::retrieveContextIfPresent(int64_t context_id) { + std::lock_guard guard(autograd_context_lock_); + auto it = autograd_context_.find(context_id); + if (it != autograd_context_.end()) { + return it->second; + } else { + return nullptr; + } +} + int64_t DistAutogradContainer::getMaxId() { return max_id_; } diff --git a/torch/csrc/distributed/autograd/context/container.h b/torch/csrc/distributed/autograd/context/container.h index 1bd8dbc65ca0..56fef7a62997 100644 --- a/torch/csrc/distributed/autograd/context/container.h +++ b/torch/csrc/distributed/autograd/context/container.h @@ -52,6 +52,10 @@ class TORCH_API DistAutogradContainer { // Retrieve the autograd context for a given context_id. ContextPtr retrieveContext(int64_t context_id); + // Retrieve the autograd context for a given context_id if it exists, + // otherwise return nullptr. + ContextPtr retrieveContextIfPresent(int64_t context_id); + // Retrieves the currently active autograd context for the current thread. ContextPtr currentContext(); diff --git a/torch/csrc/distributed/autograd/context/context.cpp b/torch/csrc/distributed/autograd/context/context.cpp index 9559fdb275b9..8e600d78051f 100644 --- a/torch/csrc/distributed/autograd/context/context.cpp +++ b/torch/csrc/distributed/autograd/context/context.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch { namespace distributed { @@ -103,6 +104,14 @@ std::shared_ptr DistAutogradContext:: return graphTask_; } +std::shared_ptr DistAutogradContext:: + retrieveGraphTaskIfExists() { + // Similar to retrieveGraphTask() but does not throw an exception if + // GraphTask doesn't exist. + std::lock_guard guard(lock_); + return graphTask_; +} + void DistAutogradContext::setGraphTask( std::shared_ptr graphTask) { std::lock_guard guard(lock_); @@ -125,17 +134,7 @@ void DistAutogradContext::addOutstandingRpc( const c10::optional& futErr) { if (futErr) { // If we have an error, let the local autograd engine know about it. 
-      std::runtime_error err((*futErr).what());
-      std::unique_lock<std::mutex> lock(lock_);
-      if (graphTask_) {
-        graphTask_->set_exception_without_signal(nullptr);
-        lock.unlock();
-        graphTask_->future_result_->setErrorIfNeeded(err.what());
-      } else {
-        LOG(WARNING)
-            << "Ignoring error since GraphTask is no longer valid: "
-            << err.what();
-      }
+      setGraphTaskException((*futErr).what());
     }
   });
   std::lock_guard<std::mutex> guard(lock_);
@@ -193,6 +192,33 @@ std::shared_ptr<rpc::FutureMessage> DistAutogradContext::
   return state->future;
 }
 
+bool DistAutogradContext::setGraphTaskException(const std::string& errorMsg) {
+  std::unique_lock<std::mutex> lock(lock_);
+  if (graphTask_) {
+    if (graphTask_->has_error_) {
+      lock.unlock();
+      return true;
+    }
+    graphTask_->set_exception_without_signal(nullptr);
+    lock.unlock();
+    graphTask_->future_result_->setErrorIfNeeded(errorMsg);
+    return false;
+  } else {
+    LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
+                 << errorMsg;
+    return true;
+  }
+}
+
+void DistAutogradContext::propagateAutogradError(const std::string& errorMsg) {
+  auto neighborNodes = getKnownWorkerIds();
+  auto agent = rpc::RpcAgent::getCurrentRpcAgent();
+  for (const auto& node : neighborNodes) {
+    DistAutogradFailureReq msg(contextId_, errorMsg);
+    agent->send(agent->getWorkerInfo(node), std::move(msg).toMessage());
+  }
+}
+
 std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
     int64_t autograd_message_id) {
   std::lock_guard<std::mutex> guard(lock_);
diff --git a/torch/csrc/distributed/autograd/context/context.h b/torch/csrc/distributed/autograd/context/context.h
index d2b8c0cd9dba..e43e3db6d34a 100644
--- a/torch/csrc/distributed/autograd/context/context.h
+++ b/torch/csrc/distributed/autograd/context/context.h
@@ -67,6 +67,17 @@ class TORCH_API DistAutogradContext {
   // These are the different workers that this context has sent RPCs to.
   std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;
 
+  // Propagates the Autograd Failure message to all known nodes.
+  void propagateAutogradError(const std::string& errorMsg);
+
+  // Sets an error on the GraphTask (assuming the graphTask hasn't already
+  // been set with an error). Returns true if no further processing is
+  // necessary for the autograd failure (for example, no need to propagate
+  // the failure message further); this is the case if the graphTask was
+  // already marked with an exception or is invalid. Returns false if the
+  // graphTask was newly set with an exception and further processing is
+  // required.
+  bool setGraphTaskException(const std::string& errorMsg);
+
  private:
   friend class BackwardPassCleanupGuard;
   friend class DistEngine;
@@ -82,6 +93,12 @@ class TORCH_API DistAutogradContext {
   // Retrieve the GraphTask.
   std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();
 
+  // An idempotent function for retrieving the GraphTask. Since notifying
+  // neighbors of errors during autograd is recursive, it is possible that
+  // this function is called multiple times. Thus, it should not crash if
+  // the GraphTask has already been set with an error by a previous RPC.
+  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTaskIfExists();
+
   // Set the appropriate graph task for the backward pass. Can be called only
   // once.
   void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
index 838676ac1818..660a207e4fc4 100644
--- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -347,10 +347,21 @@ void DistEngine::execute(
 
   BackwardPassCleanupGuard guard(autogradContext);
 
+  auto execFuture =
+      runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges);
+  // This callback propagates the dist autograd error to other nodes if it
+  // encounters a failure.
+  execFuture->addCallback(
+      [autogradContext](
+          const rpc::Message& /* unused */,
+          const c10::optional<utils::FutureError>& error) {
+        if (error) {
+          autogradContext->propagateAutogradError(error->what());
+        }
+      });
   // This needs to be blocking and as a result we wait for the future to
   // complete.
-  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
-      ->wait();
+  execFuture->wait();
 
   // Wait for all of the outstanding rpcs to complete.
   autogradContext->clearAndWaitForOutstandingRpcsAsync()->wait();
diff --git a/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp
new file mode 100644
index 000000000000..ff44cc5b12f9
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp
@@ -0,0 +1,66 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
+#include <torch/csrc/jit/pickle.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+DistAutogradFailureReq::DistAutogradFailureReq(
+    int64_t context_id,
+    std::string errorMsg)
+    : context_id_(context_id), errorMsg_(std::move(errorMsg)) {}
+
+rpc::Message DistAutogradFailureReq::toMessage() && {
+  std::vector<at::IValue> ivalues;
+  // add context_id and errorMsg
+  ivalues.emplace_back(context_id_);
+  ivalues.emplace_back(errorMsg_);
+
+  // Now pickle using JIT pickler.
+  std::vector<torch::Tensor> tensorTable;
+  std::vector<char> payload =
+      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
+
+  return rpc::Message(
+      std::move(payload),
+      std::move(tensorTable),
+      rpc::MessageType::DIST_AUTOGRAD_FAILURE_REQ);
+}
+
+std::unique_ptr<DistAutogradFailureReq> DistAutogradFailureReq::fromMessage(
+    const rpc::Message& message) {
+  // Unpickle the message and retrieve tupleElements.
+  auto payload = static_cast<const char*>(message.payload().data());
+  auto payload_size = message.payload().size();
+  IValue tuple = jit::unpickle(
+      payload,
+      payload_size,
+      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
+      &message.tensors());
+  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+
+  TORCH_INTERNAL_ASSERT(tupleElements.size() == 2);
+
+  // recover errorMsg
+  std::string errorMsg = tupleElements.back().toString()->string();
+  tupleElements.pop_back();
+
+  // recover context_id
+  int64_t context_id = tupleElements.back().toInt();
+  tupleElements.pop_back();
+
+  return std::make_unique<DistAutogradFailureReq>(context_id, errorMsg);
+}
+
+int64_t DistAutogradFailureReq::getContextId() {
+  return context_id_;
+}
+
+std::string DistAutogradFailureReq::getErrorMsg() {
+  return errorMsg_;
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h
new file mode 100644
index 000000000000..0af188d51c6a
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+#include <torch/csrc/distributed/rpc/types.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Used to notify other workers when there is an autograd error.
+class TORCH_API DistAutogradFailureReq : public rpc::RpcCommandBase {
+ public:
+  DistAutogradFailureReq(int64_t context_id, std::string errorMsg);
+  // Serialization and deserialization methods.
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<DistAutogradFailureReq> fromMessage(
+      const rpc::Message& message);
+
+  int64_t getContextId();
+  std::string getErrorMsg();
+
+ private:
+  int64_t context_id_;
+  std::string errorMsg_;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp
new file mode 100644
index 000000000000..4f3a8096cce6
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp
@@ -0,0 +1,23 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+rpc::Message DistAutogradFailureResp::toMessage() && {
+  std::vector<torch::Tensor> tensors;
+  std::vector<char> payload;
+  return rpc::Message(
+      std::move(payload),
+      std::move(tensors),
+      rpc::MessageType::DIST_AUTOGRAD_FAILURE_RESP);
+}
+
+std::unique_ptr<DistAutogradFailureResp> DistAutogradFailureResp::fromMessage(
+    const rpc::Message& message /* unused */) {
+  return std::unique_ptr<DistAutogradFailureResp>();
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h
new file mode 100644
index 000000000000..5cfb43b82ed3
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+#include <torch/csrc/distributed/rpc/types.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Empty response for DistAutogradFailureReq. Sent to acknowledge receipt of
+// a DistAutogradFailureReq.
+class TORCH_API DistAutogradFailureResp : public rpc::RpcCommandBase {
+ public:
+  DistAutogradFailureResp() = default;
+  // Serialization and deserialization methods.
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<DistAutogradFailureResp> fromMessage(
+      const rpc::Message& message);
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp
index a62538004345..625577502186 100644
--- a/torch/csrc/distributed/rpc/message.cpp
+++ b/torch/csrc/distributed/rpc/message.cpp
@@ -86,7 +86,9 @@ bool Message::isRequest() const {
       MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
       MessageType::FORWARD_AUTOGRAD_REQ == type_ ||
       // Cleanup Autograd context request
-      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_;
+      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ ||
+      // Autograd Backward Error Notification request
+      MessageType::DIST_AUTOGRAD_FAILURE_REQ == type_;
 }
 
 bool Message::isResponse() const {
@@ -101,7 +103,9 @@ bool Message::isResponse() const {
       MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
       MessageType::FORWARD_AUTOGRAD_RESP == type_ ||
       // Cleanup autograd context response
-      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_;
+      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ ||
+      // Autograd Backward Error Notification response
+      MessageType::DIST_AUTOGRAD_FAILURE_RESP == type_;
 }
 
 int64_t Message::id() const {
diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h
index 9799e5b90f7a..7501eaffbaae 100644
--- a/torch/csrc/distributed/rpc/message.h
+++ b/torch/csrc/distributed/rpc/message.h
@@ -44,6 +44,9 @@ enum MessageType {
   CLEANUP_AUTOGRAD_CONTEXT_REQ = 19,
   CLEANUP_AUTOGRAD_CONTEXT_RESP = 20,
 
+  DIST_AUTOGRAD_FAILURE_REQ = 21,
+  DIST_AUTOGRAD_FAILURE_RESP = 22,
+
   // Other internal message types
   EXCEPTION = 55,
   UNKNOWN = 60
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index 643fe8a98a72..d9fd2d8076c3 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -6,6 +6,8 @@
 #include
 #include
 #include
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
 #include
 #include
 #include
@@ -396,11 +398,22 @@ void RequestCallbackImpl::processRpc(
     case MessageType::BACKWARD_AUTOGRAD_REQ: {
       auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
       const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
+      std::shared_ptr<DistAutogradContext> autogradContext;
 
-      // Retrieve the appropriate autograd context.
-      auto autogradContext =
-          DistAutogradContainer::getInstance().retrieveContext(
+      // In rare cases, a BACKWARD_AUTOGRAD_REQ may arrive after the
+      // context has been cleaned up due to an error during the backward pass.
+      // In such situations, we can ignore this message since no further
+      // gradient computations should take place on this context.
+      autogradContext =
+          DistAutogradContainer::getInstance().retrieveContextIfPresent(
               autogradMetadata.autogradContextId);
+      if (!autogradContext) {
+        LOG(INFO) << "Ignoring Backward Autograd Request. Context "
+                  << autogradMetadata.autogradContextId
+                  << " already cleaned up due to autograd error";
+        markComplete(std::move(PropagateGradientsResp()).toMessage());
+        return;
+      }
 
       // Lookup the appropriate 'send' function to enqueue.
       std::shared_ptr<SendRpcBackward> sendFunction =
@@ -441,6 +454,33 @@ void RequestCallbackImpl::processRpc(
       markComplete(std::move(CleanupAutogradContextResp()).toMessage());
       return;
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
+      auto& backwardFailureReq = static_cast<DistAutogradFailureReq&>(rpc);
+      auto errorContextId = backwardFailureReq.getContextId();
+      auto errorMsg = backwardFailureReq.getErrorMsg();
+      // Mark the given context's graphTask with an error if the context has
+      // not been cleaned up yet. Then take the given error message and
+      // propagate it to all known neighbor workers.
+      std::shared_ptr<DistAutogradContext> autogradContext;
+      autogradContext =
+          DistAutogradContainer::getInstance().retrieveContextIfPresent(
+              errorContextId);
+      if (!autogradContext) {
+        LOG(INFO) << "Ignoring Dist Autograd Failure Request. Context "
+                  << errorContextId
+                  << " already cleaned up due to autograd error";
+        markComplete(std::move(DistAutogradFailureResp()).toMessage());
+        return;
+      }
+      // Mark the graphTask with an exception.
+      bool graphTaskHadError =
+          autogradContext->setGraphTaskException(errorMsg);
+      // Propagate the dist autograd failure message to known workers.
+      if (!graphTaskHadError) {
+        autogradContext->propagateAutogradError(errorMsg);
+      }
+      markComplete(std::move(DistAutogradFailureResp()).toMessage());
+      return;
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Request type ", messageType, " not supported.");
diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp
index bdbd9e18c302..4302612b1059 100644
--- a/torch/csrc/distributed/rpc/utils.cpp
+++ b/torch/csrc/distributed/rpc/utils.cpp
@@ -2,6 +2,8 @@
 #include
 #include
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
 #include
 #include
 #include
@@ -58,6 +60,9 @@ std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
       return autograd::CleanupAutogradContextReq::fromMessage(request);
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
+      return autograd::DistAutogradFailureReq::fromMessage(request);
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Request type ", request.type(), " not supported.");
@@ -109,6 +114,9 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(
     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
       return autograd::CleanupAutogradContextResp::fromMessage(response);
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_RESP: {
+      return autograd::DistAutogradFailureResp::fromMessage(response);
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Response type ", response.type(), " not supported.");
diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
index 3005a1f43b5b..ffb23df3c6f7 100644
--- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
+++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -1179,6 +1179,37 @@ class DistAutogradTest(RpcAgentTestFixture):
             # Run backwards, and validate we receive an error.
             dist_autograd.backward(context_id, [val.sum()])
 
+    @dist_init
+    def test_backward_intermediate_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            # Perform some ops before error simulation.
+            tmp = (t1 + t2) * (t1 + t2)
+            t3 = SimulateBackwardError.apply(tmp)
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown at some intermediate node
+            # in the chain.
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.add, args=(t2, t1)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.mul, args=(val, t3)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.matmul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.div, args=(val, t2)
+            )
+
+            with self.assertRaises(RuntimeError):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward(context_id, [val.sum()])
+
+        self.assertTrue(_all_contexts_cleaned_up())
+
     @dist_init(clean_shutdown=False)
     @unittest.skipIf(
         IS_MACOS,
@@ -1570,17 +1601,6 @@ class DistAutogradTest(RpcAgentTestFixture):
         ):
             dist_autograd.backward(context_id, [t1.sum()])
 
-        # HACK: Killing workers since otherwise the autograd engine gets stuck on
-        # other nodes. The proper fix would be addressing:
-        # https://github.com/pytorch/pytorch/issues/27643, which would inform
-        # other nodes about the failure.
-        # The autograd engine gets stuck on other nodes since they're waiting to
-        # receive gradients from the node that received an error (and as a
-        # result it didn't execute the rest of the graph).
-        dist.barrier()
-        rpc.shutdown(graceful=False)
-        sys.exit(0)
-
     @classmethod
     def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
         embedding = embedding_rref.local_value()
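
Note: the tests above rely on the SimulateBackwardError helper that already
exists in dist_autograd_test.py. For reference, it is essentially a custom
autograd Function whose backward() raises; a minimal sketch of that helper:

    import torch

    class SimulateBackwardError(torch.autograd.Function):
        # Identity in forward; raises in backward, simulating a failure on
        # whichever worker executes this node during the distributed
        # backward pass.

        @staticmethod
        def forward(ctx, input):
            return input.clone()

        @staticmethod
        def backward(ctx, grad_output):
            raise RuntimeError("Simulate error on backward pass")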