[1.5 Release][Dist Autograd][Better Engineering] Notify Workers on Failure during Distributed Autograd (#34638)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638
Fixes: https://github.com/pytorch/pytorch/issues/27643

This PR manages notifying workers in the event of a failure during distributed autograd. It gracefully propagates errors across all nodes in the backward pass and sets state in the local autograd engines accordingly.

(Note: this ignores all push blocking failures!)

Test Plan: Added two new tests that check errors thrown at an intermediate node during distributed autograd. Ensured that all existing distributed autograd tests pass.

Differential Revision: D20164420

fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
Committed by: Facebook GitHub Bot
Parent: a73dfcf8cf
Commit: 5f67c923f1
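The user-visible effect of this change: previously, an exception raised inside the backward pass on one worker left the autograd engines on the other workers blocked waiting for gradients; now the failure is propagated so every participating worker unblocks and the caller sees the error. A minimal sketch of the scenario, assuming RPC is already initialized across two workers, with an error-injecting autograd function modeled on the test suite's SimulateBackwardError helper:

import torch
import torch.distributed.rpc as rpc
import torch.distributed.autograd as dist_autograd

class SimulateBackwardError(torch.autograd.Function):
    # Illustrative helper modeled on the one in the test suite below.
    @staticmethod
    def forward(ctx, input):
        return input  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad):
        raise RuntimeError("Simulated error during backward pass")

def run_backward_with_remote_op():
    # Assumes rpc.init_rpc("worker0", rank=0, world_size=2) has run here
    # and "worker1" exists; names are illustrative.
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = SimulateBackwardError.apply(torch.rand((3, 3), requires_grad=True))
    with dist_autograd.context() as context_id:
        # The remote op records send/recv autograd functions on both workers.
        val = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))
        try:
            # The backward error is now propagated to every worker that
            # participated in this context, so none is left blocked.
            dist_autograd.backward(context_id, [val.sum()])
        except RuntimeError as e:
            print("backward failed:", e)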
@@ -532,6 +532,8 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp
+    ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
     ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
@@ -61,6 +61,8 @@ libtorch_sources = [
     "torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp",
     "torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.cpp",
     "torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.cpp",
     "torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp",
     "torch/csrc/distributed/rpc/message.cpp",
     "torch/csrc/distributed/rpc/python_call.cpp",
@@ -820,7 +820,7 @@ void Engine::mark_graph_task_completed(std::shared_ptr<GraphTask>& graph_task) {
     graph_task->future_result_->markCompleted(
         std::move(graph_task->captured_vars_));
   } catch (std::exception& e) {
-    graph_task->future_result_->setError(e.what());
+    graph_task->future_result_->setErrorIfNeeded(e.what());
   }
 }
 
@@ -204,6 +204,16 @@ ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
   return autograd_context_.at(context_id);
 }
 
+ContextPtr DistAutogradContainer::retrieveContextIfPresent(int64_t context_id) {
+  std::lock_guard<std::mutex> guard(autograd_context_lock_);
+  auto it = autograd_context_.find(context_id);
+  if (it != autograd_context_.end()) {
+    return it->second;
+  } else {
+    return nullptr;
+  }
+}
+
 int64_t DistAutogradContainer::getMaxId() {
   return max_id_;
 }
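The new accessor differs from retrieveContext only in its miss behavior: at() throws for an unknown context id, while find() lets the caller receive nullptr and treat the incoming message as stale. In Python terms it is the difference between d[k] and d.get(k); a rough analogy of the two accessors (hypothetical class, illustration only, not the real API):

class DistAutogradContainerSketch:
    def __init__(self):
        self._contexts = {}  # context_id -> context

    def retrieve_context(self, context_id):
        # Like std::unordered_map::at(): raises if the id is unknown.
        return self._contexts[context_id]

    def retrieve_context_if_present(self, context_id):
        # Like find() returning nullptr: yields None for ids that were
        # already cleaned up, so callers can ignore stale messages.
        return self._contexts.get(context_id)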
@@ -52,6 +52,10 @@ class TORCH_API DistAutogradContainer {
   // Retrieve the autograd context for a given context_id.
   ContextPtr retrieveContext(int64_t context_id);
 
+  // Retrieve the autograd context for a given context_id if it exists,
+  // otherwise return nullptr.
+  ContextPtr retrieveContextIfPresent(int64_t context_id);
+
   // Retrieves the currently active autograd context for the current thread.
   ContextPtr currentContext();
 
@@ -3,6 +3,7 @@
 #include <c10/util/Exception.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/distributed/autograd/context/context.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
 
 namespace torch {
 namespace distributed {
@@ -103,6 +104,14 @@ std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
   return graphTask_;
 }
 
+std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
+    retrieveGraphTaskIfExists() {
+  // Similar to retrieveGraphTask() but does not throw an exception if
+  // GraphTask doesn't exist.
+  std::lock_guard<std::mutex> guard(lock_);
+  return graphTask_;
+}
+
 void DistAutogradContext::setGraphTask(
     std::shared_ptr<torch::autograd::GraphTask> graphTask) {
   std::lock_guard<std::mutex> guard(lock_);
@@ -125,17 +134,7 @@ void DistAutogradContext::addOutstandingRpc(
       const c10::optional<utils::FutureError>& futErr) {
     if (futErr) {
       // If we have an error, let the local autograd engine know about it.
-      std::runtime_error err((*futErr).what());
-      std::unique_lock<std::mutex> lock(lock_);
-      if (graphTask_) {
-        graphTask_->set_exception_without_signal(nullptr);
-        lock.unlock();
-        graphTask_->future_result_->setErrorIfNeeded(err.what());
-      } else {
-        LOG(WARNING)
-            << "Ignoring error since GraphTask is no longer valid: "
-            << err.what();
-      }
+      setGraphTaskException((*futErr).what());
     }
   });
   std::lock_guard<std::mutex> guard(lock_);
@@ -193,6 +192,33 @@ std::shared_ptr<rpc::FutureMessage> DistAutogradContext::
   return state->future;
 }
 
+bool DistAutogradContext::setGraphTaskException(const std::string& errorMsg) {
+  std::unique_lock<std::mutex> lock(lock_);
+  if (graphTask_) {
+    if (graphTask_->has_error_) {
+      lock.unlock();
+      return true;
+    }
+    graphTask_->set_exception_without_signal(nullptr);
+    lock.unlock();
+    graphTask_->future_result_->setErrorIfNeeded(errorMsg);
+    return false;
+  } else {
+    LOG(WARNING) << "Ignoring error since GraphTask is no longer valid: "
+                 << errorMsg;
+    return true;
+  }
+}
+
+void DistAutogradContext::propagateAutogradError(const std::string& errorMsg) {
+  auto neighborNodes = getKnownWorkerIds();
+  auto agent = rpc::RpcAgent::getCurrentRpcAgent();
+  for (const auto& node : neighborNodes) {
+    DistAutogradFailureReq msg(contextId_, errorMsg);
+    agent->send(agent->getWorkerInfo(node), std::move(msg).toMessage());
+  }
+}
+
 std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
     int64_t autograd_message_id) {
   std::lock_guard<std::mutex> guard(lock_);
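setGraphTaskException is what makes the failure notification idempotent: only the first call on a live GraphTask sets the error and returns false, telling the caller to keep propagating; every later call, or a call after the context is torn down, returns true and stops the recursion. A small Python model of that contract, using hypothetical stand-in types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GraphTaskSketch:
    has_error: bool = False
    error_msg: Optional[str] = None

def set_graph_task_exception(graph_task: Optional[GraphTaskSketch],
                             error_msg: str) -> bool:
    # Returns True if no further processing is needed (task gone or already
    # marked with an error); False if this call set the error, meaning the
    # caller should go on to propagate the failure to neighbor workers.
    if graph_task is None:
        print("Ignoring error since GraphTask is no longer valid:", error_msg)
        return True
    if graph_task.has_error:
        return True
    graph_task.has_error = True
    graph_task.error_msg = error_msg  # stands in for setErrorIfNeeded
    return False

task = GraphTaskSketch()
assert set_graph_task_exception(task, "boom") is False  # first call marks it
assert set_graph_task_exception(task, "boom") is True   # idempotent afterwards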
@@ -67,6 +67,17 @@ class TORCH_API DistAutogradContext {
   // These are the different workers that this context has sent RPCs to.
   std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;
 
+  // Propagates the Autograd Failure message to all known nodes.
+  void propagateAutogradError(const std::string& errorMsg);
+
+  // Sets an error on the Graph Task (assuming the graphTask hasn't already
+  // been set with an error). Returns true if no further processing is
+  // necessary for the autograd failures (for example, no need to propagate the
+  // failure message further). This is the case if the graphTask was already
+  // marked with an exception or the graphTask was invalid. Returns false if
+  // the graphTask was set with an exception and further processing is required.
+  bool setGraphTaskException(const std::string& errorMsg);
+
  private:
   friend class BackwardPassCleanupGuard;
   friend class DistEngine;
@@ -82,6 +93,12 @@ class TORCH_API DistAutogradContext {
   // Retrieve the GraphTask.
   std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();
 
+  // An idempotent function for retrieving the GraphTask. Since notifying
+  // neighbors of errors during autograd is recursive, it is possible that this
+  // function is called multiple times. Thus, it should not crash if the
+  // GraphTask has already been set with an error by a previous RPC.
+  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTaskIfExists();
+
   // Set the appropriate graph task for the backward pass. Can be called only
   // once.
   void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);
@@ -347,10 +347,21 @@ void DistEngine::execute(
 
   BackwardPassCleanupGuard guard(autogradContext);
 
+  auto execFuture =
+      runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges);
+  // This callback propagates the dist autograd error to other nodes if it
+  // encounters a failure.
+  execFuture->addCallback(
+      [autogradContext](
+          const rpc::Message& /* unused */,
+          const c10::optional<torch::utils::FutureError>& error) {
+        if (error) {
+          autogradContext->propagateAutogradError(error->what());
+        }
+      });
   // This needs to be blocking and as a result we wait for the future to
   // complete.
-  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
-      ->wait();
+  execFuture->wait();
 
   // Wait for all of the outstanding rpcs to complete.
   autogradContext->clearAndWaitForOutstandingRpcsAsync()->wait();
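The execute path now keeps a handle to the engine future, hangs an error callback off it, and only then blocks, so neighbor notification fires even though the local wait still surfaces the error. A sketch of the same future-plus-callback shape using Python's concurrent.futures as a stand-in for the C++ FutureMessage API:

from concurrent.futures import ThreadPoolExecutor

def run_engine():
    # Stand-in for runEngineAndAccumulateGradients: fails partway through.
    raise RuntimeError("Error during backward pass")

def propagate_autograd_error(exc):
    # Stand-in for DistAutogradContext::propagateAutogradError.
    print("notifying neighbor workers:", exc)

with ThreadPoolExecutor(max_workers=1) as pool:
    exec_future = pool.submit(run_engine)
    # Attach the callback before blocking, mirroring execFuture->addCallback.
    exec_future.add_done_callback(
        lambda f: propagate_autograd_error(f.exception())
        if f.exception() is not None else None)
    try:
        # Blocking wait, mirroring execFuture->wait(); re-raises locally.
        exec_future.result()
    except RuntimeError as e:
        print("backward failed locally:", e)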
@@ -0,0 +1,66 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+DistAutogradFailureReq::DistAutogradFailureReq(
+    int64_t context_id,
+    std::string errorMsg)
+    : context_id_(context_id), errorMsg_(std::move(errorMsg)) {}
+
+rpc::Message DistAutogradFailureReq::toMessage() && {
+  std::vector<at::IValue> ivalues;
+  // add context_id and errorMsg
+  ivalues.emplace_back(context_id_);
+  ivalues.emplace_back(errorMsg_);
+
+  // Now pickle using JIT pickler.
+  std::vector<torch::Tensor> tensorTable;
+  std::vector<char> payload =
+      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
+
+  return rpc::Message(
+      std::move(payload),
+      std::move(tensorTable),
+      rpc::MessageType::DIST_AUTOGRAD_FAILURE_REQ);
+}
+
+std::unique_ptr<DistAutogradFailureReq> DistAutogradFailureReq::fromMessage(
+    const rpc::Message& message) {
+  // Unpickle the message and retrieve tupleElements.
+  auto payload = static_cast<const char*>(message.payload().data());
+  auto payload_size = message.payload().size();
+  IValue tuple = jit::unpickle(
+      payload,
+      payload_size,
+      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
+      &message.tensors());
+  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+
+  TORCH_INTERNAL_ASSERT(tupleElements.size() == 2);
+
+  // recover errorMsg
+  std::string errorMsg = tupleElements.back().toString()->string();
+  tupleElements.pop_back();
+
+  // recover context_id
+  int64_t context_id = tupleElements.back().toInt();
+  tupleElements.pop_back();
+
+  return std::make_unique<DistAutogradFailureReq>(context_id, errorMsg);
+}
+
+int64_t DistAutogradFailureReq::getContextId() {
+  return context_id_;
+}
+
+std::string DistAutogradFailureReq::getErrorMsg() {
+  return errorMsg_;
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
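On the wire the request is just a two-element tuple of (context_id, errorMsg) pickled into the payload, plus the DIST_AUTOGRAD_FAILURE_REQ type tag; fromMessage pops the fields back off in reverse order. A sketch of the round trip using Python's pickle as a stand-in for the JIT pickler:

import pickle

DIST_AUTOGRAD_FAILURE_REQ = 21  # MessageType value added in message.h below

def to_message(context_id: int, error_msg: str):
    # Fields are serialized in the same order the C++ code emplaces them.
    payload = pickle.dumps((context_id, error_msg))
    return payload, DIST_AUTOGRAD_FAILURE_REQ

def from_message(payload: bytes):
    context_id, error_msg = pickle.loads(payload)
    return context_id, error_msg

payload, msg_type = to_message(7, "error in backward pass")
assert from_message(payload) == (7, "error in backward pass")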
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+#include <vector>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Used to notify other workers when there is an autograd error.
+class TORCH_API DistAutogradFailureReq : public rpc::RpcCommandBase {
+ public:
+  DistAutogradFailureReq(int64_t context_id, std::string errorMsg);
+  // Serialization and deserialization methods.
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<DistAutogradFailureReq> fromMessage(
+      const rpc::Message& message);
+
+  int64_t getContextId();
+  std::string getErrorMsg();
+
+ private:
+  int64_t context_id_;
+  std::string errorMsg_;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
@@ -0,0 +1,23 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+rpc::Message DistAutogradFailureResp::toMessage() && {
+  std::vector<torch::Tensor> tensors;
+  std::vector<char> payload;
+  return rpc::Message(
+      std::move(payload),
+      std::move(tensors),
+      rpc::MessageType::DIST_AUTOGRAD_FAILURE_RESP);
+}
+
+std::unique_ptr<DistAutogradFailureResp> DistAutogradFailureResp::fromMessage(
+    const rpc::Message& message /* unused */) {
+  return std::unique_ptr<DistAutogradFailureResp>();
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+#include <vector>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Empty response for DistAutogradFailureReq. Sent to acknowledge receipt of
+// a DistAutogradFailureReq.
+class TORCH_API DistAutogradFailureResp : public rpc::RpcCommandBase {
+ public:
+  DistAutogradFailureResp() = default;
+  // Serialization and deserialization methods.
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<DistAutogradFailureResp> fromMessage(
+      const rpc::Message& message);
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
@@ -86,7 +86,9 @@ bool Message::isRequest() const {
       MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
       MessageType::FORWARD_AUTOGRAD_REQ == type_ ||
       // Cleanup Autograd context request
-      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_;
+      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ ||
+      // Autograd Backward Error Notification request
+      MessageType::DIST_AUTOGRAD_FAILURE_REQ == type_;
 }
 
 bool Message::isResponse() const {
@@ -101,7 +103,9 @@ bool Message::isResponse() const {
       MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
       MessageType::FORWARD_AUTOGRAD_RESP == type_ ||
      // Cleanup autograd context response
-      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_;
+      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ ||
+      // Autograd Backward Error Notification response
+      MessageType::DIST_AUTOGRAD_FAILURE_RESP == type_;
 }
 
 int64_t Message::id() const {
@@ -44,6 +44,9 @@ enum MessageType {
   CLEANUP_AUTOGRAD_CONTEXT_REQ = 19,
   CLEANUP_AUTOGRAD_CONTEXT_RESP = 20,
 
+  DIST_AUTOGRAD_FAILURE_REQ = 21,
+  DIST_AUTOGRAD_FAILURE_RESP = 22,
+
   // Other internal message types
   EXCEPTION = 55,
   UNKNOWN = 60
@@ -6,6 +6,8 @@
 #include <torch/csrc/distributed/autograd/engine/dist_engine.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
@@ -396,11 +398,22 @@ void RequestCallbackImpl::processRpc(
     case MessageType::BACKWARD_AUTOGRAD_REQ: {
       auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
       const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
+      std::shared_ptr<DistAutogradContext> autogradContext;
 
-      // Retrieve the appropriate autograd context.
-      auto autogradContext =
-          DistAutogradContainer::getInstance().retrieveContext(
+      // In rare cases, a BACKWARD_AUTOGRAD_REQ may arrive after the
+      // context has been cleaned up due to an error during the backward pass.
+      // In such situations, we can ignore this message since no further
+      // gradient computations should take place on this context.
+      autogradContext =
+          DistAutogradContainer::getInstance().retrieveContextIfPresent(
               autogradMetadata.autogradContextId);
+      if (!autogradContext) {
+        LOG(INFO) << "Ignoring Backward Autograd Request. Context "
+                  << autogradMetadata.autogradContextId
+                  << " already cleaned up due to autograd error";
+        markComplete(std::move(PropagateGradientsResp()).toMessage());
+        return;
+      }
 
       // Lookup the appropriate 'send' function to enqueue.
       std::shared_ptr<SendRpcBackward> sendFunction =
@@ -441,6 +454,33 @@ void RequestCallbackImpl::processRpc(
       markComplete(std::move(CleanupAutogradContextResp()).toMessage());
       return;
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
+      auto& backwardFailureReq = static_cast<DistAutogradFailureReq&>(rpc);
+      auto errorContextId = backwardFailureReq.getContextId();
+      auto errorMsg = backwardFailureReq.getErrorMsg();
+      // Mark the given context's graphTask with an error if the context has
+      // not been cleaned up yet. Then use the given error message and
+      // propagate it to all known neighbor workers.
+      std::shared_ptr<DistAutogradContext> autogradContext;
+      autogradContext =
+          DistAutogradContainer::getInstance().retrieveContextIfPresent(
+              errorContextId);
+      if (!autogradContext) {
+        LOG(INFO) << "Ignoring Dist Autograd Failure Request. Context "
+                  << errorContextId
+                  << " already cleaned up due to autograd error";
+        markComplete(std::move(DistAutogradFailureResp()).toMessage());
+        return;
+      }
+      // Mark the graphTask with an exception.
+      bool graphTaskHadError = autogradContext->setGraphTaskException(errorMsg);
+      // Propagate the dist autograd failure message to known workers.
+      if (!graphTaskHadError) {
+        autogradContext->propagateAutogradError(errorMsg);
+      }
+      markComplete(std::move(DistAutogradFailureResp()).toMessage());
+      return;
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Request type ", messageType, " not supported.");
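End to end, the handler is: look up the context (ignoring stale requests), mark the GraphTask, fan out only on the first marking, and always acknowledge with a DistAutogradFailureResp. A hypothetical two-worker simulation of that round trip, with direct calls standing in for the RPC agent, shows why the propagation terminates:

class Worker:
    def __init__(self, name):
        self.name = name
        self.contexts = {}   # context_id -> {"has_error": bool}
        self.neighbors = []  # peers that share the autograd context

    def process_failure_req(self, context_id, error_msg):
        ctx = self.contexts.get(context_id)
        if ctx is None:
            return "DIST_AUTOGRAD_FAILURE_RESP"  # stale: ack and ignore
        had_error = ctx["has_error"]
        ctx["has_error"] = True
        if not had_error:
            # First time this worker sees the failure: fan out to neighbors.
            for peer in self.neighbors:
                peer.process_failure_req(context_id, error_msg)
        return "DIST_AUTOGRAD_FAILURE_RESP"  # always acknowledge

w0, w1 = Worker("worker0"), Worker("worker1")
w0.contexts[7] = {"has_error": False}
w1.contexts[7] = {"has_error": False}
w0.neighbors, w1.neighbors = [w1], [w0]
# w1 notifies w0 back, but w0 is already marked, so the cycle stops.
w0.process_failure_req(7, "error in backward pass")
assert w0.contexts[7]["has_error"] and w1.contexts[7]["has_error"]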
@@ -2,6 +2,8 @@
 
 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/dist_autograd_failure_resp.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
@@ -58,6 +60,9 @@ std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
       return autograd::CleanupAutogradContextReq::fromMessage(request);
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_REQ: {
+      return autograd::DistAutogradFailureReq::fromMessage(request);
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Request type ", request.type(), " not supported.");
@@ -109,6 +114,9 @@ std::unique_ptr<RpcCommandBase> deserializeResponse(
     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
       return autograd::CleanupAutogradContextResp::fromMessage(response);
     }
+    case MessageType::DIST_AUTOGRAD_FAILURE_RESP: {
+      return autograd::DistAutogradFailureResp::fromMessage(response);
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Response type ", response.type(), " not supported.");
@@ -1179,6 +1179,37 @@ class DistAutogradTest(RpcAgentTestFixture):
                 # Run backwards, and validate we receive an error.
                 dist_autograd.backward(context_id, [val.sum()])
 
+    @dist_init
+    def test_backward_intermediate_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            # Perform some ops before error simulation.
+            tmp = (t1 + t2) * (t1 + t2)
+            t3 = SimulateBackwardError.apply(tmp)
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown at some intermediate node
+            # in the chain.
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.add, args=(t2, t1)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.mul, args=(val, t3)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.matmul, args=(val, t2)
+            )
+            val = rpc.rpc_sync(
+                "worker{}".format(self._next_rank()), torch.div, args=(val, t2)
+            )
+
+            with self.assertRaises(RuntimeError):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward(context_id, [val.sum()])
+
+        self.assertTrue(_all_contexts_cleaned_up())
+
     @dist_init(clean_shutdown=False)
     @unittest.skipIf(
         IS_MACOS,
@@ -1570,17 +1601,6 @@ class DistAutogradTest(RpcAgentTestFixture):
         ):
             dist_autograd.backward(context_id, [t1.sum()])
 
-        # HACK: Killing workers since otherwise the autograd engine gets stuck on
-        # other nodes. The proper fix would be addressing:
-        # https://github.com/pytorch/pytorch/issues/27643, which would inform
-        # other nodes about the failure.
-        # The autograd engine gets stuck on other nodes since they're waiting to
-        # receive gradients from the node that received an error (and as a
-        # result it didn't execute the rest of the graph).
-        dist.barrier()
-        rpc.shutdown(graceful=False)
-        sys.exit(0)
-
     @classmethod
     def _call_remote_embedding(cls, embedding_rref, input, offsets, per_sample_weights):
         embedding = embedding_rref.local_value()