pytorch/torch/csrc/distributed/rpc/message.cpp
Omkar Salpekar 5f67c923f1 [1.5 Release][Dist Autograd][Better Engineering] Notify Workers on Failure during Distributed Autograd (#34638)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638

Fixes: https://github.com/pytorch/pytorch/issues/27643

This PR handles notifying workers when a failure occurs during distributed autograd. It gracefully propagates errors across all nodes participating in the backward pass and sets state in the local autograd engines accordingly.

(Note: this ignores all push blocking failures!)

Test Plan: Added 2 new tests that check error handling when errors are thrown on an intermediate node during distributed autograd. Ensured that all existing distributed autograd tests pass.

Differential Revision: D20164420

fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
2020-03-18 18:56:14 -07:00
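
For orientation, below is a minimal sketch (not part of this change) of how the exception helpers defined in this file could wrap a failure into an RPC reply; the handler function, its name, and the synthetic error are illustrative assumptions:

#include <torch/csrc/distributed/rpc/message.h>

#include <stdexcept>

using torch::distributed::rpc::createExceptionResponse;
using torch::distributed::rpc::Message;

// Hypothetical server-side handler: if processing a request throws, wrap the
// error in an EXCEPTION message so the calling worker observes the failure.
Message processOrReportError(Message&& request) {
  const int64_t requestId = request.id();
  try {
    // ... run the real request handler here (omitted) ...
    throw std::runtime_error("backward pass failed on this worker");
  } catch (const std::exception& e) {
    // Copies e.what() into the payload and tags the reply with
    // MessageType::EXCEPTION and the original request id.
    return createExceptionResponse(e, requestId);
  }
}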


#include <torch/csrc/distributed/rpc/message.h>

namespace torch {
namespace distributed {
namespace rpc {

Message::Message() = default;
Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type),
      id_(id) {}

Message::Message(const Message& other) = default;

Message::Message(Message&& other) noexcept = default;
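// Copy assignment, implemented with the copy-and-swap idiom: copy rhs's
// payload and tensors, build a temporary Message from them, and swap the
// temporary into *this.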
Message& Message::operator=(Message const& rhs) & {
  auto payload = rhs.payload_;
  auto tensors = rhs.tensors_;
  Message(std::move(payload), std::move(tensors), rhs.type_, rhs.id_)
      .swap(*this);
  return *this;
}

Message& Message::operator=(Message&& rhs) & {
  Message(std::move(rhs.payload_), std::move(rhs.tensors_), rhs.type_, rhs.id_)
      .swap(*this);
  return *this;
}

void Message::swap(Message& rhs) noexcept {
  std::swap(payload_, rhs.payload_);
  std::swap(tensors_, rhs.tensors_);
  std::swap(type_, rhs.type_);
  std::swap(id_, rhs.id_);
}
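// The &&-qualified accessors below can only be called on a Message rvalue;
// they move the underlying data out instead of copying it.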
std::vector<char>&& Message::movePayload() && {
  return std::move(payload_);
}

const std::vector<char>& Message::payload() const {
  return payload_;
}

std::vector<torch::Tensor>&& Message::moveTensors() && {
  return std::move(tensors_);
}

std::vector<torch::Tensor>& Message::tensors() {
  return tensors_;
}

const std::vector<torch::Tensor>& Message::tensors() const {
  return tensors_;
}

MessageType Message::type() const {
  return type_;
}
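// Message-type classification: isRequest() covers messages sent from a caller
// to a callee, isResponse() covers the corresponding replies.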
bool Message::isRequest() const {
  return MessageType::SCRIPT_CALL == type_ || // dist.rpc on builtin ops
      MessageType::PYTHON_CALL == type_ || // dist.rpc on Python UDFs
      MessageType::SCRIPT_REMOTE_CALL == type_ || // dist.remote on builtin ops
      MessageType::PYTHON_REMOTE_CALL == type_ || // dist.remote on Python UDFs
      // RRef related internal messages
      MessageType::SCRIPT_RREF_FETCH_CALL == type_ ||
      MessageType::PYTHON_RREF_FETCH_CALL == type_ ||
      MessageType::RREF_USER_DELETE == type_ ||
      MessageType::RREF_CHILD_ACCEPT == type_ ||
      MessageType::RREF_FORK_REQUEST == type_ ||
      // Autograd message
      MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
      MessageType::FORWARD_AUTOGRAD_REQ == type_ ||
      // Cleanup Autograd context request
      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ ||
      // Autograd Backward Error Notification request
      MessageType::DIST_AUTOGRAD_FAILURE_REQ == type_;
}
bool Message::isResponse() const {
  return MessageType::SCRIPT_RET == type_ || // ret of dist.rpc on builtin ops
      MessageType::PYTHON_RET == type_ || // ret of dist.rpc on Python UDFs
      MessageType::REMOTE_RET == type_ || // ret of dist.remote
      MessageType::SCRIPT_RREF_FETCH_RET == type_ || // ret on RRef::toHere()
      MessageType::PYTHON_RREF_FETCH_RET == type_ || // ret on RRef::toHere()
      MessageType::EXCEPTION == type_ || // propagate back exceptions
      MessageType::RREF_ACK == type_ || // ret of other types
      // Autograd response
      MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
      MessageType::FORWARD_AUTOGRAD_RESP == type_ ||
      // Cleanup autograd context response
      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ ||
      // Autograd Backward Error Notification response
      MessageType::DIST_AUTOGRAD_FAILURE_RESP == type_;
}
int64_t Message::id() const {
  return id_;
}

void Message::setId(int64_t id) {
  id_ = id;
}
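// Helpers that convert a caught error into an EXCEPTION response whose
// payload carries the error string back to the calling worker.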
Message createExceptionResponse(const std::exception& e, int64_t id) {
  return createExceptionResponse(e.what(), id);
}

Message createExceptionResponse(const std::string& exceptionStr, int64_t id) {
  std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
  return Message(
      std::move(payload),
      std::vector<torch::Tensor>(),
      MessageType::EXCEPTION,
      id);
}
} // namespace rpc
} // namespace distributed
} // namespace torch