pytorch/torch/csrc/distributed/autograd/context/container.cpp
Omkar Salpekar 5f67c923f1 [1.5 Release][Dist Autograd][Better Engineering] Notify Workers on Failure during Distributed Autograd (#34638)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638

Fixes: https://github.com/pytorch/pytorch/issues/27643

This PR handles notifying workers in the event of a failure during distributed autograd. It gracefully propagates errors across all nodes in the backward pass and sets the state of the local autograd engines accordingly.

(Note: this ignores all push blocking failures!)

Test Plan: Added 2 new tests that check error handling when an error is thrown on an intermediate node during distributed autograd. Ensured that all existing distributed autograd tests still pass.

Differential Revision: D20164420

fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
2020-03-18 18:56:14 -07:00
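
For orientation, here is a minimal, hypothetical sketch (not part of the file below) of how a single worker exercises the container API defined in this file. It assumes the DistAutogradContext::contextId() accessor from the corresponding context header; the worker id and the call site are illustrative only.

#include <torch/csrc/distributed/autograd/context/container.h>

using torch::distributed::autograd::DistAutogradContainer;

void exampleWorkerLifecycle() {
  // Done once per process, with this worker's globally unique id.
  auto& container = DistAutogradContainer::init(/*worker_id=*/0);

  // Roughly what entering `with dist_autograd.context() as context_id:` does:
  // create a fresh context and bind it to the current thread.
  auto context = container.newContext();
  const int64_t context_id = context->contextId();

  // ... run the distributed forward and backward pass under this context ...

  // Roughly what exiting the context manager does: notify all known workers
  // (via CleanupAutogradContextReq) and erase the local context.
  container.releaseContext(context_id);
}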


#include <torch/csrc/distributed/autograd/context/container.h>
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
namespace torch {
namespace distributed {
namespace autograd {
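// Autograd context ids and autograd message ids are 64-bit values: the high
// 16 bits hold the id of the worker that created them and the low 48 bits
// hold an auto-incrementing counter. The constants below define that split.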
constexpr int kAutoIncrementBits = 48;
constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;
constexpr int kMaxWorkerId = 65535;
constexpr int64_t kInvalidContextId = -1;
// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;
// Lock to ensure DistAutogradContainer is initialized only once.
static std::mutex dist_container_init_lock_;

DistAutogradContainer::DistAutogradContainer()
    : next_context_id_(0),
      worker_id_(0),
      initialized_(false),
      next_autograd_message_id_(0),
      max_id_(0) {}

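// Called once per worker process. Reserves this worker's slice of the id
// space: both context ids and message ids start at (worker_id << 48) and may
// grow up to max_id_ = (worker_id << 48) | kAutoIncrementMask.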
DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]");

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_,
      "Container is already initialized! Cannot initialize it twice!");

  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

DistAutogradContainer& DistAutogradContainer::getInstance() {
  auto& instance = getInstanceInternal();
  TORCH_CHECK(
      instance.initialized_,
      "Need to initialize distributed autograd using "
      "torch.distributed.autograd.init()");
  return instance;
}

DistAutogradContainer& DistAutogradContainer::getInstanceInternal() {
  // Leaky singleton to avoid module destructor race.
  static DistAutogradContainer* container = new DistAutogradContainer();
  return *container;
}

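// Autograd message ids identify send/recv pairs in the distributed autograd
// graph; they use the same worker-id-prefixed numbering scheme as context ids,
// so they are globally unique across workers.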
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  auto it = autograd_context_.find(context_id);
  if (it != autograd_context_.end()) {
    return it->second;
  }

  auto& context =
      autograd_context_
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

rpc::worker_id_t DistAutogradContainer::getWorkerId() const {
  return worker_id_;
}

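// Creates a fresh context and binds it to the calling thread. This is what
// backs entering `with torch.distributed.autograd.context() as context_id:`
// on the Python side.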
const ContextPtr DistAutogradContainer::newContext() {
  TORCH_CHECK(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");

  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_context_id_ < max_id_);

  auto& context =
      autograd_context_
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(next_context_id_),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(next_context_id_)))
          .first->second;
  current_context_id_ = next_context_id_++;
  return context;
}

bool DistAutogradContainer::hasValidContext() const {
  return current_context_id_ != kInvalidContextId;
}

ContextPtr DistAutogradContainer::currentContext() {
  TORCH_CHECK(
      hasValidContext(),
      "Current thread doesn't have a valid autograd context. Please wrap your "
      "code using: `with torch.distributed.autograd.context() as context_id` "
      "to generate a valid context");

  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  auto it = autograd_context_.find(current_context_id_);
  TORCH_CHECK(
      it != autograd_context_.end(),
      "Couldn't find autograd context "
      "data for current autograd context id");
  return it->second;
}

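// Two release paths follow. Both notify the workers recorded in the context
// before erasing it locally; releaseContextIfPresent tolerates a context that
// was already cleaned up (e.g. by an in-flight cleanup RPC), whereas
// releaseContext treats a missing context as an error.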
void DistAutogradContainer::releaseContextIfPresent(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  // No-op if the context does not exist on this node. This could happen if an
  // in-flight RPC has already released the context on this node.
  if (autograd_context_.find(context_id) == autograd_context_.end()) {
    return;
  }
  sendReleaseContextRpc(context_id);
  eraseContextIdAndReset(context_id);
}

void DistAutogradContainer::releaseContext(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  TORCH_CHECK(
      autograd_context_.find(context_id) != autograd_context_.end(),
      "Could not find autograd context with id: ",
      context_id);
  sendReleaseContextRpc(context_id);
  eraseContextIdAndReset(context_id);
}

void DistAutogradContainer::sendReleaseContextRpc(int64_t context_id) {
  // Notify other workers to clean up their contexts.
  auto workerIds =
      autograd_context_.find(context_id)->second->getKnownWorkerIds();

  // agent.send() or getCurrentRpcAgent may throw an error in the case of an
  // ungraceful shutdown, where we are shutting down RPC and also processing
  // this message in a separate thread concurrently. In this case, don't throw
  // here.
  try {
    auto agent = rpc::RpcAgent::getCurrentRpcAgent();
    for (const auto& worker_id : workerIds) {
      agent->send(
          agent->getWorkerInfo(worker_id),
          CleanupAutogradContextReq(context_id).toMessage());
    }
  } catch (const std::exception& e) {
    LOG(INFO)
        << "Failed to send RPC to clear Dist Autograd context to some nodes: "
        << e.what();
  }
}

void DistAutogradContainer::eraseContextIdAndReset(int64_t context_id) {
  autograd_context_.erase(context_id);
  if (current_context_id_ == context_id) {
    // Reset the thread_local current context id, since it is no longer valid.
    current_context_id_ = kInvalidContextId;
  }
}

void DistAutogradContainer::isValidContext(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  TORCH_CHECK(
      autograd_context_.find(context_id) != autograd_context_.end(),
      "Could not find autograd context with id: ",
      context_id);
}

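// Lookup helpers: retrieveContext treats an unknown id as an error, while
// retrieveContextIfPresent returns nullptr for callers that can tolerate an
// already-released context.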
ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  TORCH_CHECK(
      autograd_context_.find(context_id) != autograd_context_.end(),
      "Could not find autograd context with id: ",
      context_id);
  return autograd_context_.at(context_id);
}

ContextPtr DistAutogradContainer::retrieveContextIfPresent(int64_t context_id) {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  auto it = autograd_context_.find(context_id);
  if (it != autograd_context_.end()) {
    return it->second;
  } else {
    return nullptr;
  }
}

int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

void DistAutogradContainer::setCurrentContextId(int64_t contextId) {
  TORCH_INTERNAL_ASSERT(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");
  current_context_id_ = contextId;
}

void DistAutogradContainer::clearCurrentContext() {
  current_context_id_ = kInvalidContextId;
}

size_t DistAutogradContainer::numAutogradContexts() const {
  std::lock_guard<std::mutex> guard(autograd_context_lock_);
  return autograd_context_.size();
}

} // namespace autograd
} // namespace distributed
} // namespace torch