mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53209 closes #40048 Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D26791524 Pulled By: mrshenli fbshipit-source-id: fc75589f9707014334fcfae6f05af3c04217783b
46 lines
1.3 KiB
C++
46 lines
1.3 KiB
C++
#include <torch/csrc/distributed/rpc/agent_utils.h>

#include <limits>
#include <utility>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
std::unordered_map<std::string, worker_id_t> collectNames(
|
|
::c10d::PrefixStore store,
|
|
const worker_id_t selfId,
|
|
const std::string& selfName,
|
|
const int worldSize) {
|
|
std::vector<uint8_t> selfNameVector(
|
|
(uint8_t*)selfName.c_str(),
|
|
(uint8_t*)selfName.c_str() + selfName.length());
|
|
store.set(c10::to_string(selfId), selfNameVector);
|
|
|
|
std::unordered_map<std::string, worker_id_t> nameToId;
|
|
nameToId.reserve(worldSize);
|
|
nameToId.emplace(selfName, selfId);
|
|
for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
|
|
if (workerId == selfId) {
|
|
continue;
|
|
}
|
|
std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
|
|
std::string workerName(
|
|
(char*)workerNameVector.data(), workerNameVector.size());
|
|
|
|
TORCH_CHECK(
|
|
nameToId.find(workerName) == nameToId.end(),
|
|
"RPC worker name ",
|
|
workerName,
|
|
" is not unique. Workers ",
|
|
nameToId.find(workerName)->second,
|
|
" and ",
|
|
workerId,
|
|
" share the same name.");
|
|
|
|
nameToId.emplace(workerName, workerId);
|
|
}
|
|
return nameToId;
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|