Use Store to collect and verify names in all RPC agents (#53209)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53209

closes #40048
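
For context, the rendezvous pattern that the new `collectNames` helper implements can be sketched in a few lines of Python using the public `torch.distributed` store API (TCPStore/PrefixStore): each rank publishes its own name under its rank id, then reads every other rank's entry and verifies that no two workers share a name. This is an illustrative sketch only, not the agent code itself; the port, worker names, and single-process setup are made up for the example.

```python
from torch.distributed import PrefixStore, TCPStore

def collect_names(store, self_id, self_name, world_size):
    # Publish our own rank -> name entry, then read everyone else's.
    store.set(str(self_id), self_name)
    name_to_id = {self_name: self_id}
    for worker_id in range(world_size):
        if worker_id == self_id:
            continue
        # get() blocks until the peer has published its name.
        worker_name = store.get(str(worker_id)).decode()
        if worker_name in name_to_id:
            raise RuntimeError(
                f"RPC worker name {worker_name} is not unique: ranks "
                f"{name_to_id[worker_name]} and {worker_id} share it")
        name_to_id[worker_name] = worker_id
    return name_to_id

# Single-process illustration: pretend ranks 1 and 2 already published their names.
base_store = TCPStore("127.0.0.1", 29512, 1, True)
names_store = PrefixStore("names", base_store)
for rank, name in [(1, "worker1"), (2, "worker2")]:
    names_store.set(str(rank), name)
print(collect_names(names_store, 0, "worker0", 3))  # {'worker0': 0, 'worker1': 1, 'worker2': 2}
```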

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D26791524

Pulled By: mrshenli

fbshipit-source-id: fc75589f9707014334fcfae6f05af3c04217783b
Author: Shen Li
Date: 2021-03-07 16:46:13 -08:00
Committed by: Facebook GitHub Bot
Parent: affdcce833
Commit: c7b1979b6b
15 changed files with 74 additions and 64 deletions

View File

@@ -353,6 +353,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
# up being empty. Downstream targets should also add a #ifdef guard.
if(NOT WIN32)
add_library(process_group_agent
"${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.h"
)
@@ -361,6 +363,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
if(USE_TENSORPIPE)
add_library(tensorpipe_agent
"${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h"
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp"

View File

@@ -26,6 +26,7 @@ class TestE2EProcessGroup : public TestE2EBase {
store, 0, numWorkers, options);
rpcAgent = std::make_shared<ProcessGroupAgent>(
store,
"worker",
pg,
std::max(16U, std::thread::hardware_concurrency()),

View File

@@ -611,6 +611,7 @@ libtorch_python_distributed_core_sources = [
libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
"torch/csrc/distributed/autograd/init.cpp",
"torch/csrc/distributed/rpc/agent_utils.cpp",
"torch/csrc/distributed/rpc/init.cpp",
"torch/csrc/distributed/rpc/process_group_agent.cpp",
"torch/csrc/distributed/rpc/py_rref.cpp",

View File

@@ -77,6 +77,7 @@ class ProcessGroupRpcBackendOptions(RpcBackendOptions):
class ProcessGroupAgent(RpcAgent):
def __init__(
self,
store: Store,
worker_name: str,
pg: ProcessGroup,
numSendRecvThreads: int,

View File

@@ -1,4 +1,4 @@
from ._distributed_c10d import ProcessGroup
from ._distributed_c10d import ProcessGroup, Store
from ._distributed_rpc import ProcessGroupAgent, ProcessGroupRpcBackendOptions, WorkerInfo
from typing import List, Dict, overload
from datetime import timedelta
@@ -23,6 +23,7 @@ class FaultyProcessGroupRpcBackendOptions(ProcessGroupRpcBackendOptions):
class FaultyProcessGroupAgent(ProcessGroupAgent):
def __init__(
self,
store: Store,
name: str,
process_group: ProcessGroup,
num_send_recv_threads: int,

View File

@@ -0,0 +1,45 @@
#include <torch/csrc/distributed/rpc/agent_utils.h>

namespace torch {
namespace distributed {
namespace rpc {

std::unordered_map<std::string, worker_id_t> collectNames(
    ::c10d::PrefixStore store,
    const worker_id_t selfId,
    const std::string& selfName,
    const int worldSize) {
  std::vector<uint8_t> selfNameVector(
      (uint8_t*)selfName.c_str(),
      (uint8_t*)selfName.c_str() + selfName.length());
  store.set(c10::to_string(selfId), selfNameVector);

  std::unordered_map<std::string, worker_id_t> nameToId;
  nameToId.reserve(worldSize);
  nameToId.emplace(selfName, selfId);
  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
    if (workerId == selfId) {
      continue;
    }
    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
    std::string workerName(
        (char*)workerNameVector.data(), workerNameVector.size());

    TORCH_CHECK(
        nameToId.find(workerName) == nameToId.end(),
        "RPC worker name ",
        workerName,
        " is not unique. Workers ",
        nameToId.find(workerName)->second,
        " and ",
        workerId,
        " share the same name.");

    nameToId.emplace(workerName, workerId);
  }
  return nameToId;
}

} // namespace rpc
} // namespace distributed
} // namespace torch

View File

@@ -1,10 +1,8 @@
#pragma once
#include <c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
namespace distributed {
namespace rpc {
@@ -16,33 +14,7 @@ std::unordered_map<std::string, worker_id_t> collectNames(
::c10d::PrefixStore store,
const worker_id_t selfId,
const std::string& selfName,
const int worldSize) {
std::vector<uint8_t> selfNameVector(
(uint8_t*)selfName.c_str(),
(uint8_t*)selfName.c_str() + selfName.length());
store.set(c10::to_string(selfId), selfNameVector);
std::unordered_map<std::string, worker_id_t> nameToId;
nameToId.reserve(worldSize);
nameToId.emplace(selfName, selfId);
for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
if (workerId == selfId) {
continue;
}
std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
std::string workerName(
(char*)workerNameVector.data(), workerNameVector.size());
TORCH_CHECK(
nameToId.find(workerName) == nameToId.end(),
"RPC worker name ",
workerName,
" is not unique.");
nameToId.emplace(workerName, workerId);
}
return nameToId;
}
const int worldSize);
} // namespace rpc
} // namespace distributed

View File

@@ -542,11 +542,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
py::cast(kDefaultNumSendRecvThreads);
shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
.def(py::init([](std::string workerName,
.def(py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout) {
return std::make_unique<ProcessGroupAgent>(
store,
std::move(workerName),
pg,
numSendRecvThreads,

View File

@@ -3,6 +3,7 @@
#include <c10/util/C++17.h>
#include <c10d/ProcessGroup.hpp>
#include <fmt/format.h>
#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
@@ -57,38 +58,8 @@ const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
void ProcessGroupAgent::collectNames() {
const std::string& workerName = workerInfo_.name_;
const auto worldSize = pg_->getSize();
// use c10d allgather to collect names
torch::Tensor nameTensor =
torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar);
memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length());
std::vector<torch::Tensor> inputName = {nameTensor};
std::vector<std::vector<torch::Tensor>> outputNames(1);
for (int i = 0; i < worldSize; ++i) {
outputNames[0].emplace_back(
torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar}));
}
pg_->allgather(outputNames, inputName)->wait();
// convert collected name tensors into string names
for (worker_id_t i = 0; i < worldSize; ++i) {
torch::Tensor& tensor = outputNames[0][i];
std::string peerName((const char*)tensor.storage().data<signed char>());
TORCH_CHECK(
nameMap_.find(peerName) == nameMap_.end(),
"RpcAgent name ",
peerName,
" is not unique.");
nameMap_[std::move(peerName)] = i;
}
}
ProcessGroupAgent::ProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
@@ -109,7 +80,12 @@ ProcessGroupAgent::ProcessGroupAgent(
metrics_.resize(ProcessGroupAgentMetrics::N_METRICS);
metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME] =
std::make_unique<AverageMetricsTracker>(kGilAverageWaitTime);
collectNames();
nameMap_ = collectNames(
::c10d::PrefixStore("names", store),
workerInfo_.id_,
workerInfo_.name_,
pg_->getSize());
auto workerRankIter = nameMap_.find(workerInfo_.name_);
TORCH_CHECK(
workerRankIter != nameMap_.end(),

View File

@@ -1,6 +1,7 @@
#pragma once
#include <c10/core/thread_pool.h>
#include <c10d/PrefixStore.hpp>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
@@ -60,6 +61,7 @@ struct RecvWork {
class TORCH_API ProcessGroupAgent : public RpcAgent {
public:
ProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
@@ -148,7 +150,6 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
FutureInfo() = delete;
};
void collectNames();
// handle a SendWork request. This serializes the payload inside the work
// object, and sends the message to the receiver using the underlying
// ProcessGroup.

View File

@@ -11,6 +11,7 @@ std::string fromVec(const std::vector<char>& vec) {
}
FaultyProcessGroupAgent::FaultyProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
@@ -19,6 +20,7 @@ FaultyProcessGroupAgent::FaultyProcessGroupAgent(
const std::unordered_map<std::string, float>& messageTypesToDelay,
int failNumSends)
: ProcessGroupAgent(
store,
std::move(workerName),
std::move(pg),
numSendRecvThreads,

View File

@@ -34,6 +34,7 @@ struct FaultyProcessGroupRpcBackendOptions
class FaultyProcessGroupAgent : public ProcessGroupAgent {
public:
FaultyProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads,

View File

@@ -67,6 +67,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
.def(
py::init<
const c10::intrusive_ptr<::c10d::Store>,
std::string,
c10::intrusive_ptr<::c10d::ProcessGroup>,
int,
@@ -74,6 +75,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
const std::vector<std::string>&,
const std::unordered_map<std::string, float>&,
int>(),
py::arg("store"),
py::arg("name"),
py::arg("process_group"),
py::arg("num_send_recv_threads"),

View File

@@ -63,6 +63,7 @@ def _faulty_process_group_init_backend_handler(
)
return FaultyProcessGroupAgent(
store,
name,
group,
rpc_backend_options.num_send_recv_threads,

View File

@@ -157,6 +157,7 @@ def _process_group_init_backend_handler(
# TODO: add try-except and destroy _agent in all processes if any fails.
return ProcessGroupAgent(
store,
name,
group,
rpc_backend_options.num_send_recv_threads,