#include <torch/csrc/distributed/autograd/context/container.h>

#include <c10/util/Exception.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>

namespace torch::distributed::autograd {

constexpr int kAutoIncrementBits = 48;
constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;
constexpr int kMaxWorkerId = 65535;
constexpr int kNumCleanupContextRetries = 20;

constexpr int64_t kInvalidContextId = -1;
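
// Context ids and autograd message ids are 64-bit values laid out as
// [16-bit worker_id | 48-bit auto-increment counter]: init() seeds both
// counters with `worker_id << kAutoIncrementBits`, and the overflow asserts
// below keep them under `(worker_id << kAutoIncrementBits) |
// kAutoIncrementMask`, so ids generated on different workers never collide.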

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

// Lock to ensure DistAutogradContainer is initialized only once.
static std::mutex dist_container_init_lock_;

DistAutogradContainer::DistAutogradContainer(uint32_t num_shards)
    : next_context_id_(0),
      worker_id_(0),
      initialized_(false),
      autograd_contexts_(num_shards),
      num_shards_(num_shards),
      next_autograd_message_id_(0),
      max_id_(0) {
  // num_shards has to be a power of 2 for the modulo trick in 'getShard'
  // to work.
  TORCH_INTERNAL_ASSERT((num_shards & (num_shards - 1)) == 0);
}

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]");

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    LOG(INFO) << "DistAutogradContainer is already initialized";
    return container;
  }

  container.worker_id_ = static_cast<int16_t>(worker_id);
  container.next_context_id_ = worker_id << kAutoIncrementBits;
  container.next_autograd_message_id_ = worker_id << kAutoIncrementBits;
  container.max_id_ = (kAutoIncrementMask | (worker_id << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

uint32_t DistAutogradContainer::computeNumShards() {
  uint32_t num_shards = 1;
  auto num_hw_threads = std::thread::hardware_concurrency();
  if (num_hw_threads == 0) {
    num_shards = kNumDefaultShards;
  } else {
    // Compute the smallest power of 2 that is at least twice the hardware
    // concurrency (e.g. 12 hardware threads yields 32 shards).
    while (num_shards < num_hw_threads * 2) {
      num_shards <<= 1;
    }
  }
  VLOG(1) << "Number of shards for DistAutogradContainer: " << num_shards;
  return num_shards;
}

inline DistAutogradContainer::ContextsShard& DistAutogradContainer::getShard(
    int64_t context_id) {
  // num_shards_ has to be a power of 2 for this modulo trick to work
  // (validated during init).
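  // For example, with num_shards_ == 32, `context_id & 31` equals
  // `context_id % 32` for the non-negative ids generated here, without the
  // cost of an integer division.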
  return autograd_contexts_[context_id & (num_shards_ - 1)];
}

DistAutogradContainer& DistAutogradContainer::getInstance() {
  auto& instance = getInstanceInternal();
  TORCH_CHECK(
      instance.initialized_,
      "Need to initialize distributed autograd using "
      "torch.distributed.autograd.init()");
  return instance;
}

DistAutogradContainer& DistAutogradContainer::getInstanceInternal() {
  // Leaky singleton to avoid module destructor race.
  static DistAutogradContainer* container =
      new DistAutogradContainer(computeNumShards());
  return *container;
}
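
// Autograd message ids use the same [worker_id | counter] layout as context
// ids (see init()), so message ids generated on different workers are
// globally unique.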
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  if (it != shard.contexts.end()) {
    return it->second;
  }

  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

rpc::worker_id_t DistAutogradContainer::getWorkerId() const {
  return worker_id_;
}
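
// Create a new context for the calling thread, drawing its id from this
// worker's id range and recording it as the thread's current context id.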
const ContextPtr DistAutogradContainer::newContext() {
  TORCH_CHECK(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");

  auto context_id = next_context_id_++;
  current_context_id_ = context_id;

  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

bool DistAutogradContainer::hasValidContext() const {
  return current_context_id_ != kInvalidContextId;
}

ContextPtr DistAutogradContainer::currentContext() {
  TORCH_CHECK(
      hasValidContext(),
      "Current thread doesn't have a valid autograd context. Please wrap your "
      "code using: `with torch.distributed.autograd.context() as context_id` "
      "to generate a valid context");

  auto& shard = getShard(current_context_id_);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(current_context_id_);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Couldn't find autograd context "
      "data for current autograd context id");
  return it->second;
}

void DistAutogradContainer::releaseContextIfPresent(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  // No-op if the context does not exist on this node. This could happen if an
  // in-flight RPC has already released the context here.
  if (it == shard.contexts.end()) {
    return;
  }

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::releaseContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::unique_lock<std::mutex> lock(shard.lock);
  auto it = shard.contexts.find(context_id);

  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);

  auto knownWorkerIds = it->second->getKnownWorkerIds();
  eraseContextIdAndReset(shard, context_id);

  // Unlock since we no longer need the lock.
  lock.unlock();
  sendReleaseContextRpc(knownWorkerIds, context_id);
}

void DistAutogradContainer::sendReleaseContextRpc(
    const std::unordered_set<rpc::worker_id_t>& workerIds,
    int64_t context_id) {
  // Best-effort notification to other workers to clean up their Dist autograd
  // context, in order to reduce memory usage.
  // agent.send() or getCurrentRpcAgent may throw an error in the case of an
  // ungraceful shutdown, where we are shutting down RPC and also processing
  // this message in a separate thread concurrently. In this case, don't throw
  // here.
  std::shared_ptr<rpc::RpcAgent> agent;
  try {
    agent = rpc::RpcAgent::getCurrentRpcAgent();
  } catch (const std::exception& e) {
    LOG(INFO)
        << "Failed to send RPC to clear Dist Autograd context to all workers: "
        << e.what();
    return;
  }

  TORCH_INTERNAL_ASSERT(agent, "RPC Agent should be set.");

  rpc::RpcRetryOptions options;
  options.maxRetries = kNumCleanupContextRetries;
  for (const auto& worker_id : workerIds) {
    try {
      auto cleanupFuture = agent->sendWithRetries(
          agent->getWorkerInfo(worker_id),
          CleanupAutogradContextReq(context_id).toMessage(),
          options);

      cleanupFuture->addCallback([worker_id](rpc::JitFuture& future) {
        if (future.hasError()) {
          std::string errorMsg = c10::str(
              "Could not release Dist Autograd Context on node ",
              worker_id,
              ": ",
              future.tryRetrieveErrorMessage());
          LOG(ERROR) << errorMsg;
          return;
        }
      });
    } catch (const std::exception& e) {
      LOG(INFO)
          << "Failed to send RPC to clear Dist Autograd context to worker id: "
          << worker_id << " : " << e.what();
    }
  }
}

void DistAutogradContainer::eraseContextIdAndReset(
    DistAutogradContainer::ContextsShard& shard,
    int64_t context_id) {
  // We already have the shard lock here.
  shard.contexts.erase(context_id);

  if (current_context_id_ == context_id) {
    // Reset the thread_local current context id, since it is no longer valid.
    current_context_id_ = kInvalidContextId;
  }
}

void DistAutogradContainer::isValidContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  TORCH_CHECK(
      shard.contexts.find(context_id) != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
}

ContextPtr DistAutogradContainer::retrieveContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id);
  TORCH_CHECK(
      it != shard.contexts.end(),
      "Could not find autograd context with id: ",
      context_id);
  return it->second;
}

int64_t DistAutogradContainer::getMaxId() {
  return max_id_;
}

void DistAutogradContainer::forceCurrentContextId(int64_t contextId) {
  current_context_id_ = contextId;
}

void DistAutogradContainer::setCurrentContextId(int64_t contextId) {
  TORCH_INTERNAL_ASSERT(
      current_context_id_ == kInvalidContextId,
      "Already have an autograd context id for this thread.");
  current_context_id_ = contextId;
}

void DistAutogradContainer::clearCurrentContext() {
  current_context_id_ = kInvalidContextId;
}

size_t DistAutogradContainer::numAutogradContexts() const {
  size_t ret = 0;
  for (const auto& shard : autograd_contexts_) {
    std::lock_guard<std::mutex> guard(shard.lock);
    ret += shard.contexts.size();
  }
  return ret;
}

int64_t DistAutogradContainer::currentContextId() {
  return current_context_id_;
}

} // namespace torch::distributed::autograd