mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Use atomic operations to manipulate current RPC agent (#39663)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/39663 I was investigating a memory corruption issue and thought it may be due to a race condition in (un)setting the current RPC agent. It turns out it wasn't (still investigating...). I had already written this fix, and it is a real fix (there could really be a race condition), so I'm sending it out to see whether there's interest in merging it. I believe its practical usefulness is however very limited, since typically the current RPC agent is only changed twice (at start and at shutdown) and thus there's limited risk for races. As there may be some confusion on atomicity of shared_ptrs, let me clarify a few things from the get go. Operations on the control blocks of shared_ptrs (i.e., increasing and decreasing the refcounts) are atomic, which means that it is safe to manipulate *two different* shared_ptrs that point to the *same* object from *different* threads. However, the shared_ptr object itself is not atomic, which means that it is *not* safe to manipulate the *same* shared_ptr from two *different* threads. For that reason, the STL provides atomic functions explicitly specialized for shared_ptrs: https://en.cppreference.com/w/cpp/memory/shared_ptr/atomic (in C++ 20, they are being replaced by a specialization of std::atomic<std::shared_ptr<T>>). Note that this has been called "the worst question of all of C++" by Louis Brandy at his CppCon talk: https://youtu.be/lkgszkPnV8g?t=1210 ghstack-source-id: 105475005 Test Plan: Unit tests Differential Revision: D21932817 fbshipit-source-id: da33fedd98efb820f284583ce7ff1c1c531dea9c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
af05158c56
commit
7d85e77076
@ -239,21 +239,34 @@ const WorkerInfo& RpcAgent::getWorkerInfo() const {
|
||||
std::shared_ptr<RpcAgent> RpcAgent::currentRpcAgent_ = nullptr;
|
||||
|
||||
bool RpcAgent::isCurrentRpcAgentSet() {
|
||||
return currentRpcAgent_ != nullptr;
|
||||
return std::atomic_load(¤tRpcAgent_) != nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<RpcAgent> RpcAgent::getCurrentRpcAgent() {
|
||||
TORCH_INTERNAL_ASSERT(currentRpcAgent_, "Current RPC agent is not set!");
|
||||
return currentRpcAgent_;
|
||||
std::shared_ptr<RpcAgent> agent = std::atomic_load(¤tRpcAgent_);
|
||||
TORCH_INTERNAL_ASSERT(agent, "Current RPC agent is not set!");
|
||||
return agent;
|
||||
}
|
||||
|
||||
void RpcAgent::setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent) {
|
||||
if (rpcAgent) {
|
||||
TORCH_INTERNAL_ASSERT(!currentRpcAgent_, "Current RPC agent is set!");
|
||||
std::shared_ptr<RpcAgent> previousAgent;
|
||||
// Use compare_exchange so that we don't actually perform the exchange if
|
||||
// that would trigger the assert just below. See:
|
||||
// https://en.cppreference.com/w/cpp/atomic/atomic_compare_exchange
|
||||
std::atomic_compare_exchange_strong(
|
||||
¤tRpcAgent_, &previousAgent, std::move(rpcAgent));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
previousAgent == nullptr, "Current RPC agent is set!");
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(currentRpcAgent_, "Current RPC agent is not set!");
|
||||
// We can't use compare_exchange (we don't know what value to expect) but we
|
||||
// don't need to, as the only case that would trigger the assert is if we
|
||||
// replaced nullptr with nullptr, which we can just do as it has no effect.
|
||||
std::shared_ptr<RpcAgent> previousAgent =
|
||||
std::atomic_exchange(¤tRpcAgent_, std::move(rpcAgent));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
previousAgent != nullptr, "Current RPC agent is not set!");
|
||||
}
|
||||
currentRpcAgent_ = std::move(rpcAgent);
|
||||
}
|
||||
|
||||
void RpcAgent::setTypeResolver(std::shared_ptr<TypeResolver> typeResolver) {
|
||||
|
Reference in New Issue
Block a user