pytorch/torch/csrc/distributed/rpc/agent_utils.cpp
Howard Huang 938afa37a3 Remove process group barrier and all_reduce function calls from tensorpipe agent (#65946)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65946

Add a new function in agent_utils that synchronizes active call counts across processes using the store. This is intended to replace the barrier and all_reduce that the process group performed during RPC shutdown.
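
For illustration, a minimal sketch of how an agent's shutdown path could use this store-based synchronization; the `waitForQuiescence` helper and the `localActiveCalls` callback are placeholders, not the actual TensorPipe agent code:

#include <torch/csrc/distributed/rpc/agent_utils.h>

#include <functional>

namespace torch {
namespace distributed {
namespace rpc {

// Hypothetical helper (not part of the agent): re-run the store-based
// synchronization until no process in the group reports outstanding RPC
// calls, at which point it is safe to tear the agent down.
void waitForQuiescence(
    ::c10d::PrefixStore store,
    int worldSize,
    const std::function<int()>& localActiveCalls) {
  while (true) {
    // Every rank contributes its current local count; the return value is
    // the group-wide total once all ranks have checked in.
    int totalCalls = syncCallCount(store, worldSize, localActiveCalls());
    if (totalCalls == 0) {
      break; // no agent has active work left
    }
  }
}

} // namespace rpc
} // namespace distributed
} // namespace torch

Because every process adds its call count before the ready key is set, all ranks read the same total in a given round and therefore leave the loop together.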

`test_ddp_comparison` and `test_ddp_comparison_uneven_inputs` fail with these changes. It seems the RPC agents are not accessing the same store, so the total process count never reaches the world size and the barrier is never exited; I still need to investigate why this happens only for these test cases. Setting `clean_shutdown` to false skips this code path, which allows the tests to pass.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D31762736

Pulled By: H-Huang

fbshipit-source-id: cb5d0efe196f72726c63393c4293e97ec4f18548
2021-10-28 10:15:56 -07:00

#include <torch/csrc/distributed/rpc/agent_utils.h>

namespace torch {
namespace distributed {
namespace rpc {

// Publishes this worker's name under its id in the store, then reads every
// other worker's entry to build the name-to-id map, checking that all names
// are unique.
std::unordered_map<std::string, worker_id_t> collectNames(
    ::c10d::PrefixStore store,
    const worker_id_t selfId,
    const std::string& selfName,
    const int worldSize) {
  std::vector<uint8_t> selfNameVector(
      (uint8_t*)selfName.c_str(),
      (uint8_t*)selfName.c_str() + selfName.length());
  store.set(c10::to_string(selfId), selfNameVector);

  std::unordered_map<std::string, worker_id_t> nameToId;
  nameToId.reserve(worldSize);
  nameToId.emplace(selfName, selfId);
  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
    if (workerId == selfId) {
      continue;
    }
    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
    std::string workerName(
        (char*)workerNameVector.data(), workerNameVector.size());

    TORCH_CHECK(
        nameToId.find(workerName) == nameToId.end(),
        "RPC worker name ",
        workerName,
        " is not unique. Workers ",
        nameToId.find(workerName)->second,
        " and ",
        workerId,
        " share the same name.");

    nameToId.emplace(workerName, workerId);
  }
  return nameToId;
}
const std::string storeKeyBarrierId = "_ID_";
const std::string storeKeyProcessCount = "PROCESS_COUNT";
const std::string storeKeyActiveCallCount = "ACTIVE_CALLS";
const std::string storeKeyReady = "READY";
static std::atomic<int> barrierId(0);

// Generates a fresh set of store keys for each synchronization round so that
// successive calls do not read each other's counters.
std::tuple<std::string, std::string, std::string> getNextKeyIds() {
  barrierId++;
  std::string processCountKey =
      storeKeyProcessCount + storeKeyBarrierId + std::to_string(barrierId);
  std::string activeCallCountKey =
      storeKeyActiveCallCount + storeKeyBarrierId + std::to_string(barrierId);
  std::string barrierKey =
      storeKeyReady + storeKeyBarrierId + std::to_string(barrierId);
  return std::make_tuple(processCountKey, activeCallCountKey, barrierKey);
}
// Synchronizes this process with all other agent processes using only the
// store. Blocks until all ``RpcAgent``s reach this method.
// Returns the total number of active calls across all RPC agents in the group.
int syncCallCount(
    ::c10d::PrefixStore store,
    const int worldSize,
    int activeCalls) {
  std::string processCountKey, activeCallCountKey, readyKey;
  std::tie(processCountKey, activeCallCountKey, readyKey) = getNextKeyIds();

  // Add to keys which will record the number of processes and active calls
  int totalCallCount = store.add(activeCallCountKey, activeCalls);
  int totalProcessCount = store.add(processCountKey, 1);
  VLOG(1) << processCountKey << " " << totalCallCount << " "
          << totalProcessCount;

  // The last worker will need to set the ready key
  if (totalProcessCount == worldSize) {
    store.set(readyKey, std::vector<uint8_t>());
  }

  // Wait on the ready key to be set
  store.wait(std::vector<std::string>{readyKey});

  // Read count of active calls which may have changed
  auto activeCallCountData = store.get(activeCallCountKey);
  totalCallCount = std::stoi(
      std::string(activeCallCountData.begin(), activeCallCountData.end()));
  return totalCallCount;
}
} // namespace rpc
} // namespace distributed
} // namespace torch