Files
pytorch/torch/csrc/distributed/rpc/agent_utils.cpp
Shen Li c7b1979b6b Use Store collect and verify names in all RPC agents (#53209)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53209

closes #40048

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D26791524

Pulled By: mrshenli

fbshipit-source-id: fc75589f9707014334fcfae6f05af3c04217783b
2021-03-07 16:51:46 -08:00

46 lines
1.3 KiB
C++

#include <torch/csrc/distributed/rpc/agent_utils.h>

#include <limits>
namespace torch {
namespace distributed {
namespace rpc {

// Exchanges worker names through `store` and returns a name -> id map that
// covers all `worldSize` workers.
//
// The calling worker publishes its own name (as raw bytes) under the key
// `c10::to_string(selfId)`. It then reads the name stored under the key of
// every other id in [0, worldSize).
// NOTE(review): correctness relies on `store.get` waiting until the peer has
// set its key (c10d Store semantics) — confirm for any custom store.
//
// @param store     Store shared by all workers (taken by value, matching the
//                  declaration in agent_utils.h).
// @param selfId    Id of the calling worker.
// @param selfName  Name of the calling worker.
// @param worldSize Total number of workers; ids are 0 .. worldSize - 1.
// @return Map from each worker's name to its id.
// @throws c10::Error (via TORCH_CHECK) if `worldSize` cannot be represented
//         by worker_id_t ids, or if two workers registered the same name.
std::unordered_map<std::string, worker_id_t> collectNames(
    ::c10d::PrefixStore store,
    const worker_id_t selfId,
    const std::string& selfName,
    const int worldSize) {
  // Every id in [0, worldSize) must fit in worker_id_t; otherwise narrowing
  // below would alias distinct workers, and a negative size would make
  // reserve() request an enormous allocation.
  TORCH_CHECK(
      worldSize >= 0 &&
          static_cast<long long>(worldSize) <=
              static_cast<long long>(std::numeric_limits<worker_id_t>::max()) +
                  1,
      "RPC world size ",
      worldSize,
      " is out of range for worker ids.");

  // Publish our own name as raw bytes so peers can look it up by our id.
  const std::vector<uint8_t> selfNameBytes(selfName.begin(), selfName.end());
  store.set(c10::to_string(selfId), selfNameBytes);

  std::unordered_map<std::string, worker_id_t> nameToId;
  nameToId.reserve(worldSize);
  nameToId.emplace(selfName, selfId);

  // Iterate with an int counter: incrementing a narrow worker_id_t past its
  // maximum would wrap around and never reach worldSize.
  for (int rank = 0; rank < worldSize; ++rank) {
    const auto workerId = static_cast<worker_id_t>(rank);
    if (workerId == selfId) {
      continue;
    }
    const std::vector<uint8_t> workerNameBytes =
        store.get(c10::to_string(workerId));
    const std::string workerName(
        reinterpret_cast<const char*>(workerNameBytes.data()),
        workerNameBytes.size());

    // Single lookup: emplace() leaves the map unchanged when the name already
    // exists and returns an iterator to the existing entry, whose id is
    // reported in the error.
    const auto insertion = nameToId.emplace(workerName, workerId);
    TORCH_CHECK(
        insertion.second,
        "RPC worker name ",
        workerName,
        " is not unique. Workers ",
        insertion.first->second,
        " and ",
        workerId,
        " share the same name.");
  }
  return nameToId;
}

} // namespace rpc
} // namespace distributed
} // namespace torch