Use Store collect and verify names in all RPC agents (#53209)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53209

closes #40048

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D26791524

Pulled By: mrshenli

fbshipit-source-id: fc75589f9707014334fcfae6f05af3c04217783b
commit c7b1979b6b (parent affdcce833)
committed by Facebook GitHub Bot
@@ -353,6 +353,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
   # up being empty. Downstream targets should also add a #ifdef guard.
   if(NOT WIN32)
     add_library(process_group_agent
+      "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
+      "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.h"
     )
@@ -361,6 +363,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)

   if(USE_TENSORPIPE)
     add_library(tensorpipe_agent
+      "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
+      "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp"
@@ -26,6 +26,7 @@ class TestE2EProcessGroup : public TestE2EBase {
         store, 0, numWorkers, options);

     rpcAgent = std::make_shared<ProcessGroupAgent>(
+        store,
         "worker",
         pg,
         std::max(16U, std::thread::hardware_concurrency()),
@@ -611,6 +611,7 @@ libtorch_python_distributed_core_sources = [

 libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
     "torch/csrc/distributed/autograd/init.cpp",
+    "torch/csrc/distributed/rpc/agent_utils.cpp",
     "torch/csrc/distributed/rpc/init.cpp",
     "torch/csrc/distributed/rpc/process_group_agent.cpp",
     "torch/csrc/distributed/rpc/py_rref.cpp",
@@ -77,6 +77,7 @@ class ProcessGroupRpcBackendOptions(RpcBackendOptions):
 class ProcessGroupAgent(RpcAgent):
     def __init__(
         self,
+        store: Store,
         worker_name: str,
         pg: ProcessGroup,
         numSendRecvThreads: int,
@@ -1,4 +1,4 @@
-from ._distributed_c10d import ProcessGroup
+from ._distributed_c10d import ProcessGroup, Store
 from ._distributed_rpc import ProcessGroupAgent, ProcessGroupRpcBackendOptions, WorkerInfo
 from typing import List, Dict, overload
 from datetime import timedelta
@@ -23,6 +23,7 @@ class FaultyProcessGroupRpcBackendOptions(ProcessGroupRpcBackendOptions):
 class FaultyProcessGroupAgent(ProcessGroupAgent):
     def __init__(
         self,
+        store: Store,
         name: str,
         process_group: ProcessGroup,
         num_send_recv_threads: int,
torch/csrc/distributed/rpc/agent_utils.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
+#include <torch/csrc/distributed/rpc/agent_utils.h>
+
+namespace torch {
+namespace distributed {
+namespace rpc {
+
+std::unordered_map<std::string, worker_id_t> collectNames(
+    ::c10d::PrefixStore store,
+    const worker_id_t selfId,
+    const std::string& selfName,
+    const int worldSize) {
+  std::vector<uint8_t> selfNameVector(
+      (uint8_t*)selfName.c_str(),
+      (uint8_t*)selfName.c_str() + selfName.length());
+  store.set(c10::to_string(selfId), selfNameVector);
+
+  std::unordered_map<std::string, worker_id_t> nameToId;
+  nameToId.reserve(worldSize);
+  nameToId.emplace(selfName, selfId);
+  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
+    if (workerId == selfId) {
+      continue;
+    }
+    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
+    std::string workerName(
+        (char*)workerNameVector.data(), workerNameVector.size());
+
+    TORCH_CHECK(
+        nameToId.find(workerName) == nameToId.end(),
+        "RPC worker name ",
+        workerName,
+        " is not unique. Workers ",
+        nameToId.find(workerName)->second,
+        " and ",
+        workerId,
+        " share the same name.");
+
+    nameToId.emplace(workerName, workerId);
+  }
+  return nameToId;
+}
+
+} // namespace rpc
+} // namespace distributed
+} // namespace torch
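For reference, the rendezvous above is a simple publish-then-poll pattern: each rank writes its own name under its rank id, then blocks on reads of every other rank's key. Below is a minimal sketch of the same pattern using torch.distributed's Python Store bindings; the function name and the single-process demo at the bottom are illustrative, not part of this PR.

```python
from torch.distributed import TCPStore

def collect_names(store, self_id, self_name, world_size):
    """Publish self_name under self_id, then read every peer's name."""
    store.set(str(self_id), self_name)
    name_to_id = {self_name: self_id}
    for worker_id in range(world_size):
        if worker_id == self_id:
            continue
        # get() blocks until the peer has published its key.
        worker_name = store.get(str(worker_id)).decode()
        # Mirror the TORCH_CHECK above: names must be globally unique.
        assert worker_name not in name_to_id, (
            f"RPC worker name {worker_name} is not unique. Workers "
            f"{name_to_id[worker_name]} and {worker_id} share the same name.")
        name_to_id[worker_id := worker_id] = worker_id  # placeholder, see below
        name_to_id[worker_name] = worker_id
    return name_to_id

# Single-process demo (assumption: port 29500 is free locally).
store = TCPStore("127.0.0.1", 29500, 1, True)
print(collect_names(store, 0, "worker0", 1))  # {'worker0': 0}
```

Because get() blocks until the corresponding set() arrives, no barrier or collective is needed; the store itself provides the synchronization.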
@@ -1,10 +1,8 @@
 #pragma once

 #include <c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/rpc/utils.h>

 namespace torch {
 namespace distributed {
 namespace rpc {
@@ -16,33 +14,7 @@ std::unordered_map<std::string, worker_id_t> collectNames(
     ::c10d::PrefixStore store,
     const worker_id_t selfId,
     const std::string& selfName,
-    const int worldSize) {
-  std::vector<uint8_t> selfNameVector(
-      (uint8_t*)selfName.c_str(),
-      (uint8_t*)selfName.c_str() + selfName.length());
-  store.set(c10::to_string(selfId), selfNameVector);
-
-  std::unordered_map<std::string, worker_id_t> nameToId;
-  nameToId.reserve(worldSize);
-  nameToId.emplace(selfName, selfId);
-  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
-    if (workerId == selfId) {
-      continue;
-    }
-    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
-    std::string workerName(
-        (char*)workerNameVector.data(), workerNameVector.size());
-
-    TORCH_CHECK(
-        nameToId.find(workerName) == nameToId.end(),
-        "RPC worker name ",
-        workerName,
-        " is not unique.");
-
-    nameToId.emplace(workerName, workerId);
-  }
-  return nameToId;
-}
+    const int worldSize);

 } // namespace rpc
 } // namespace distributed
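The header now keeps only the declaration, and its parameter is a ::c10d::PrefixStore: the agent wraps its base store in a "names" prefix (see the constructor change further down), so rendezvous keys cannot collide with other users of the same store. A hedged Python illustration of that key namespacing (host/port are placeholders):

```python
from torch.distributed import PrefixStore, TCPStore

base = TCPStore("127.0.0.1", 29500, 1, True)  # placeholder rendezvous endpoint
names = PrefixStore("names", base)

names.set("0", "worker0")   # written under the "names" prefix
print(base.get("names/0"))  # b'worker0' -- PrefixStore joins keys as "<prefix>/<key>"
```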
@@ -542,11 +542,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           py::cast(kDefaultNumSendRecvThreads);

   shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
-      .def(py::init([](std::string workerName,
+      .def(py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                       std::string workerName,
                        const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
                        int numSendRecvThreads,
                        std::chrono::milliseconds rpcTimeout) {
             return std::make_unique<ProcessGroupAgent>(
+                store,
                 std::move(workerName),
                 pg,
                 numSendRecvThreads,
@@ -3,6 +3,7 @@
 #include <c10/util/C++17.h>
 #include <c10d/ProcessGroup.hpp>
 #include <fmt/format.h>
+#include <torch/csrc/distributed/rpc/agent_utils.h>
 #include <torch/csrc/distributed/rpc/utils.h>

 namespace torch {
@@ -57,38 +58,8 @@ const std::string kClientActiveCalls = "agent.client_active_calls";
 const std::string kServerActiveCalls = "agent.server_active_calls";
 const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";

-void ProcessGroupAgent::collectNames() {
-  const std::string& workerName = workerInfo_.name_;
-  const auto worldSize = pg_->getSize();
-
-  // use c10d allgather to collect names
-  torch::Tensor nameTensor =
-      torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar);
-  memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length());
-  std::vector<torch::Tensor> inputName = {nameTensor};
-  std::vector<std::vector<torch::Tensor>> outputNames(1);
-  for (int i = 0; i < worldSize; ++i) {
-    outputNames[0].emplace_back(
-        torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar}));
-  }
-  pg_->allgather(outputNames, inputName)->wait();
-
-  // convert collected name tensors into string names
-  for (worker_id_t i = 0; i < worldSize; ++i) {
-    torch::Tensor& tensor = outputNames[0][i];
-    std::string peerName((const char*)tensor.storage().data<signed char>());
-
-    TORCH_CHECK(
-        nameMap_.find(peerName) == nameMap_.end(),
-        "RpcAgent name ",
-        peerName,
-        " is not unique.");
-
-    nameMap_[std::move(peerName)] = i;
-  }
-}
-
 ProcessGroupAgent::ProcessGroupAgent(
+    const c10::intrusive_ptr<::c10d::Store>& store,
     std::string workerName,
     c10::intrusive_ptr<::c10d::ProcessGroup> pg,
     int numSendRecvThreads,
@@ -109,7 +80,12 @@ ProcessGroupAgent::ProcessGroupAgent(
   metrics_.resize(ProcessGroupAgentMetrics::N_METRICS);
   metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME] =
       std::make_unique<AverageMetricsTracker>(kGilAverageWaitTime);
-  collectNames();
+
+  nameMap_ = collectNames(
+      ::c10d::PrefixStore("names", store),
+      workerInfo_.id_,
+      workerInfo_.name_,
+      pg_->getSize());
   auto workerRankIter = nameMap_.find(workerInfo_.name_);
   TORCH_CHECK(
       workerRankIter != nameMap_.end(),
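The removed path above exchanged names with a fixed-width character-tensor allgather over the ProcessGroup, which capped names at WorkerInfo::MAX_NAME_LEN and required a live collective backend. A rough Python rendering of that replaced scheme, for contrast (the value of MAX_NAME_LEN and the initialized default process group are assumptions of the sketch):

```python
import torch
import torch.distributed as dist

MAX_NAME_LEN = 128  # illustrative; stands in for WorkerInfo::MAX_NAME_LEN

def collect_names_allgather(self_name, world_size):
    # Pack the name into a fixed-width int8 tensor, zero-padded.
    name = torch.zeros(MAX_NAME_LEN, dtype=torch.int8)
    raw = self_name.encode()
    name[: len(raw)] = torch.tensor(list(raw), dtype=torch.int8)
    # One collective: every rank receives every rank's padded name.
    gathered = [torch.empty(MAX_NAME_LEN, dtype=torch.int8) for _ in range(world_size)]
    dist.all_gather(gathered, name)  # assumes dist.init_process_group() was called
    return {bytes(t[t != 0].tolist()).decode(): rank for rank, t in enumerate(gathered)}
```

The store-based version drops the length cap and the dependency on a collective, and works identically for ProcessGroupAgent and TensorPipeAgent, which is the point of this PR.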
@@ -1,6 +1,7 @@
 #pragma once

 #include <c10/core/thread_pool.h>
+#include <c10d/PrefixStore.hpp>
 #include <c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/rpc/request_callback.h>
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
@@ -60,6 +61,7 @@ struct RecvWork {
 class TORCH_API ProcessGroupAgent : public RpcAgent {
  public:
   ProcessGroupAgent(
+      const c10::intrusive_ptr<::c10d::Store>& store,
       std::string workerName,
       c10::intrusive_ptr<::c10d::ProcessGroup> pg,
       int numSendRecvThreads,
@@ -148,7 +150,6 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
     FutureInfo() = delete;
   };

-  void collectNames();
   // handle a SendWork request. This serializes the payload inside the work
   // object, and sends the message to the receiver using the underlying
   // ProcessGroup.
@@ -11,6 +11,7 @@ std::string fromVec(const std::vector<char>& vec) {
 }

 FaultyProcessGroupAgent::FaultyProcessGroupAgent(
+    const c10::intrusive_ptr<::c10d::Store>& store,
     std::string workerName,
     c10::intrusive_ptr<::c10d::ProcessGroup> pg,
     int numSendRecvThreads,
@@ -19,6 +20,7 @@ FaultyProcessGroupAgent::FaultyProcessGroupAgent(
     const std::unordered_map<std::string, float>& messageTypesToDelay,
     int failNumSends)
     : ProcessGroupAgent(
+          store,
           std::move(workerName),
           std::move(pg),
           numSendRecvThreads,
@@ -34,6 +34,7 @@ struct FaultyProcessGroupRpcBackendOptions
 class FaultyProcessGroupAgent : public ProcessGroupAgent {
  public:
   FaultyProcessGroupAgent(
+      const c10::intrusive_ptr<::c10d::Store>& store,
       std::string workerName,
       c10::intrusive_ptr<c10d::ProcessGroup> pg,
       int numSendRecvThreads,
@@ -67,6 +67,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
       module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
       .def(
           py::init<
+              const c10::intrusive_ptr<::c10d::Store>,
               std::string,
               c10::intrusive_ptr<::c10d::ProcessGroup>,
               int,
@@ -74,6 +75,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
               const std::vector<std::string>&,
               const std::unordered_map<std::string, float>&,
               int>(),
+          py::arg("store"),
           py::arg("name"),
           py::arg("process_group"),
           py::arg("num_send_recv_threads"),
@@ -63,6 +63,7 @@ def _faulty_process_group_init_backend_handler(
     )

     return FaultyProcessGroupAgent(
+        store,
         name,
         group,
         rpc_backend_options.num_send_recv_threads,
@@ -157,6 +157,7 @@ def _process_group_init_backend_handler(

     # TODO: add try-except and destroy _agent in all processes if any fails.
     return ProcessGroupAgent(
+        store,
         name,
         group,
         rpc_backend_options.num_send_recv_threads,
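From the user's perspective nothing changes: init_rpc already owns a rendezvous store, and the backend handlers above now simply pass it through to the agent. A minimal smoke test under that assumption (master address/port are placeholders; run one process per rank):

```python
import os
import torch.distributed.rpc as rpc

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # placeholder rendezvous endpoint
    os.environ["MASTER_PORT"] = "29500"
    # init_rpc builds the store and the agent; with this PR, duplicate worker
    # names fail fast with "RPC worker name ... is not unique."
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    rpc.shutdown()
```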