Use Store to collect and verify names in all RPC agents (#53209)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53209. Closes #40048.
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D26791524
Pulled By: mrshenli
fbshipit-source-id: fc75589f9707014334fcfae6f05af3c04217783b
Committed by: Facebook GitHub Bot
Parent: affdcce833
Commit: c7b1979b6b
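This change threads a ::c10d::Store through every RPC agent constructor and replaces the ProcessGroup-allgather-based name exchange with a Store-based one: each worker publishes its name under its own rank, reads every peer's name back, and fails fast if two workers claim the same name. The shared helper lives in the new torch/csrc/distributed/rpc/agent_utils.cpp (shown below). A minimal Python sketch of the same algorithm, using the public torch.distributed store API; the HashStore and the single-worker call at the end are illustrative assumptions, not part of the commit:

import torch.distributed as dist

def collect_names(store, self_id: int, self_name: str, world_size: int):
    # Publish this worker's name under its rank.
    store.set(str(self_id), self_name)
    name_to_id = {self_name: self_id}
    for worker_id in range(world_size):
        if worker_id == self_id:
            continue
        # get() blocks until the peer has published its name.
        worker_name = store.get(str(worker_id)).decode()
        if worker_name in name_to_id:
            raise RuntimeError(
                f"RPC worker name {worker_name} is not unique. Workers "
                f"{name_to_id[worker_name]} and {worker_id} share the same name.")
        name_to_id[worker_name] = worker_id
    return name_to_id

# Single-process illustration with an in-memory store:
names_store = dist.PrefixStore("names", dist.HashStore())
print(collect_names(names_store, 0, "worker0", 1))  # {'worker0': 0}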
caffe2/CMakeLists.txt
@@ -353,6 +353,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
 # up being empty. Downstream targets should also add a #ifdef guard.
 if(NOT WIN32)
   add_library(process_group_agent
+    "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
+    "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
     "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp"
     "${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.h"
   )
@@ -361,6 +363,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)

   if(USE_TENSORPIPE)
     add_library(tensorpipe_agent
+      "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.cpp"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/agent_utils.h"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h"
       "${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp"
test/cpp/rpc/test_e2e_process_group.cpp
@@ -26,6 +26,7 @@ class TestE2EProcessGroup : public TestE2EBase {
         store, 0, numWorkers, options);

     rpcAgent = std::make_shared<ProcessGroupAgent>(
+        store,
         "worker",
         pg,
         std::max(16U, std::thread::hardware_concurrency()),
tools/build_variables.bzl
@@ -611,6 +611,7 @@ libtorch_python_distributed_core_sources = [

 libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
     "torch/csrc/distributed/autograd/init.cpp",
+    "torch/csrc/distributed/rpc/agent_utils.cpp",
     "torch/csrc/distributed/rpc/init.cpp",
     "torch/csrc/distributed/rpc/process_group_agent.cpp",
     "torch/csrc/distributed/rpc/py_rref.cpp",
torch/_C/_distributed_rpc.pyi
@@ -77,6 +77,7 @@ class ProcessGroupRpcBackendOptions(RpcBackendOptions):
 class ProcessGroupAgent(RpcAgent):
     def __init__(
         self,
+        store: Store,
         worker_name: str,
         pg: ProcessGroup,
         numSendRecvThreads: int,
torch/_C/_distributed_rpc_testing.pyi
@@ -1,4 +1,4 @@
-from ._distributed_c10d import ProcessGroup
+from ._distributed_c10d import ProcessGroup, Store
 from ._distributed_rpc import ProcessGroupAgent, ProcessGroupRpcBackendOptions, WorkerInfo
 from typing import List, Dict, overload
 from datetime import timedelta
@@ -23,6 +23,7 @@ class FaultyProcessGroupRpcBackendOptions(ProcessGroupRpcBackendOptions):
 class FaultyProcessGroupAgent(ProcessGroupAgent):
     def __init__(
         self,
+        store: Store,
         name: str,
         process_group: ProcessGroup,
         num_send_recv_threads: int,
torch/csrc/distributed/rpc/agent_utils.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
+#include <torch/csrc/distributed/rpc/agent_utils.h>
+
+namespace torch {
+namespace distributed {
+namespace rpc {
+
+std::unordered_map<std::string, worker_id_t> collectNames(
+    ::c10d::PrefixStore store,
+    const worker_id_t selfId,
+    const std::string& selfName,
+    const int worldSize) {
+  std::vector<uint8_t> selfNameVector(
+      (uint8_t*)selfName.c_str(),
+      (uint8_t*)selfName.c_str() + selfName.length());
+  store.set(c10::to_string(selfId), selfNameVector);
+
+  std::unordered_map<std::string, worker_id_t> nameToId;
+  nameToId.reserve(worldSize);
+  nameToId.emplace(selfName, selfId);
+  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
+    if (workerId == selfId) {
+      continue;
+    }
+    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
+    std::string workerName(
+        (char*)workerNameVector.data(), workerNameVector.size());
+
+    TORCH_CHECK(
+        nameToId.find(workerName) == nameToId.end(),
+        "RPC worker name ",
+        workerName,
+        " is not unique. Workers ",
+        nameToId.find(workerName)->second,
+        " and ",
+        workerId,
+        " share the same name.");
+
+    nameToId.emplace(workerName, workerId);
+  }
+  return nameToId;
+}
+
+} // namespace rpc
+} // namespace distributed
+} // namespace torch
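Note that collectNames() takes a ::c10d::PrefixStore rather than the raw store: the agent constructor wraps the shared store under a "names" prefix so these rank-keyed entries cannot collide with keys written by other users of the same store (for example the ProcessGroup rendezvous). The same wrapper exists in the Python API; a small illustrative sketch, with the in-memory HashStore as an assumption for the example:

import torch.distributed as dist

base = dist.HashStore()              # in-memory store for illustration
names = dist.PrefixStore("names", base)

names.set("0", "worker0")            # stored under "names/0"
base.set("0", "something-else")      # separate key, no collision
assert names.get("0") == b"worker0"
assert base.get("0") == b"something-else"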
torch/csrc/distributed/rpc/agent_utils.h
@@ -1,10 +1,8 @@
 #pragma once

-
 #include <c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/rpc/utils.h>

-
 namespace torch {
 namespace distributed {
 namespace rpc {
@@ -16,33 +14,7 @@ std::unordered_map<std::string, worker_id_t> collectNames(
     ::c10d::PrefixStore store,
     const worker_id_t selfId,
     const std::string& selfName,
-    const int worldSize) {
-  std::vector<uint8_t> selfNameVector(
-      (uint8_t*)selfName.c_str(),
-      (uint8_t*)selfName.c_str() + selfName.length());
-  store.set(c10::to_string(selfId), selfNameVector);
-
-  std::unordered_map<std::string, worker_id_t> nameToId;
-  nameToId.reserve(worldSize);
-  nameToId.emplace(selfName, selfId);
-  for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) {
-    if (workerId == selfId) {
-      continue;
-    }
-    std::vector<uint8_t> workerNameVector = store.get(c10::to_string(workerId));
-    std::string workerName(
-        (char*)workerNameVector.data(), workerNameVector.size());
-
-    TORCH_CHECK(
-        nameToId.find(workerName) == nameToId.end(),
-        "RPC worker name ",
-        workerName,
-        " is not unique.");
-
-    nameToId.emplace(workerName, workerId);
-  }
-  return nameToId;
-}
-
+    const int worldSize);

 } // namespace rpc
 } // namespace distributed
torch/csrc/distributed/rpc/init.cpp
@@ -542,11 +542,13 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
           py::cast(kDefaultNumSendRecvThreads);

   shared_ptr_class_<ProcessGroupAgent>(module, "ProcessGroupAgent", rpcAgent)
-      .def(py::init([](std::string workerName,
+      .def(py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                       std::string workerName,
                        const c10::intrusive_ptr<::c10d::ProcessGroup>& pg,
                        int numSendRecvThreads,
                        std::chrono::milliseconds rpcTimeout) {
         return std::make_unique<ProcessGroupAgent>(
+            store,
             std::move(workerName),
             pg,
             numSendRecvThreads,
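For Python callers the bound constructor now takes the store as its first argument, mirroring the C++ signature. A hedged sketch of direct construction; normally only backend_registry.py and the tests construct agents this way, real users should go through torch.distributed.rpc.init_rpc(). The address, port, timeouts, and thread count are illustrative assumptions:

from datetime import timedelta
import torch.distributed as dist
from torch._C._distributed_rpc import ProcessGroupAgent

# Single-worker store and gloo process group, as the backend registry builds them.
store = dist.TCPStore("127.0.0.1", 29500, 1, True)
pg = dist.ProcessGroupGloo(store, 0, 1, timedelta(minutes=5))
agent = ProcessGroupAgent(
    store,                   # new leading argument: used for name collection
    "worker0",               # workerName
    pg,                      # ProcessGroup backing send/recv
    16,                      # numSendRecvThreads
    timedelta(seconds=60),   # rpcTimeout
)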
torch/csrc/distributed/rpc/process_group_agent.cpp
@@ -3,6 +3,7 @@
 #include <c10/util/C++17.h>
 #include <c10d/ProcessGroup.hpp>
 #include <fmt/format.h>
+#include <torch/csrc/distributed/rpc/agent_utils.h>
 #include <torch/csrc/distributed/rpc/utils.h>

 namespace torch {
@@ -57,38 +58,8 @@ const std::string kClientActiveCalls = "agent.client_active_calls";
 const std::string kServerActiveCalls = "agent.server_active_calls";
 const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";

-void ProcessGroupAgent::collectNames() {
-  const std::string& workerName = workerInfo_.name_;
-  const auto worldSize = pg_->getSize();
-
-  // use c10d allgather to collect names
-  torch::Tensor nameTensor =
-      torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar);
-  memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length());
-  std::vector<torch::Tensor> inputName = {nameTensor};
-  std::vector<std::vector<torch::Tensor>> outputNames(1);
-  for (int i = 0; i < worldSize; ++i) {
-    outputNames[0].emplace_back(
-        torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar}));
-  }
-  pg_->allgather(outputNames, inputName)->wait();
-
-  // convert collected name tensors into string names
-  for (worker_id_t i = 0; i < worldSize; ++i) {
-    torch::Tensor& tensor = outputNames[0][i];
-    std::string peerName((const char*)tensor.storage().data<signed char>());
-
-    TORCH_CHECK(
-        nameMap_.find(peerName) == nameMap_.end(),
-        "RpcAgent name ",
-        peerName,
-        " is not unique.");
-
-    nameMap_[std::move(peerName)] = i;
-  }
-}
-
 ProcessGroupAgent::ProcessGroupAgent(
+    const c10::intrusive_ptr<::c10d::Store>& store,
     std::string workerName,
     c10::intrusive_ptr<::c10d::ProcessGroup> pg,
     int numSendRecvThreads,
@@ -109,7 +80,12 @@ ProcessGroupAgent::ProcessGroupAgent(
   metrics_.resize(ProcessGroupAgentMetrics::N_METRICS);
   metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME] =
       std::make_unique<AverageMetricsTracker>(kGilAverageWaitTime);
-  collectNames();
+  nameMap_ = collectNames(
+      ::c10d::PrefixStore("names", store),
+      workerInfo_.id_,
+      workerInfo_.name_,
+      pg_->getSize());
   auto workerRankIter = nameMap_.find(workerInfo_.name_);
   TORCH_CHECK(
       workerRankIter != nameMap_.end(),
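The deleted collectNames() above shows why the Store is a better fit: the allgather route needed a fixed-width buffer (names limited to WorkerInfo::MAX_NAME_LEN), required the ProcessGroup to already be usable, and could not say which ranks collided when a duplicate name was found. For contrast, a rough Python rendition of the removed allgather scheme; the port, the MAX_NAME_LEN value, and the single-rank setup are illustrative assumptions:

import torch
import torch.distributed as dist

MAX_NAME_LEN = 128  # stands in for WorkerInfo::MAX_NAME_LEN

def collect_names_allgather(name: str, world_size: int):
    # Encode the name into a fixed-width int8 buffer, zero-padded.
    buf = torch.zeros(MAX_NAME_LEN, dtype=torch.int8)
    data = torch.tensor(list(name.encode()), dtype=torch.int8)
    buf[: data.numel()] = data
    out = [torch.empty(MAX_NAME_LEN, dtype=torch.int8) for _ in range(world_size)]
    dist.all_gather(out, buf)  # needs an initialized process group
    # Strip the zero padding back off to recover the strings.
    return [bytes(t.tolist()).rstrip(b"\x00").decode() for t in out]

dist.init_process_group(
    "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
print(collect_names_allgather("worker0", 1))  # ['worker0']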
torch/csrc/distributed/rpc/process_group_agent.h
@@ -1,6 +1,7 @@
 #pragma once

 #include <c10/core/thread_pool.h>
+#include <c10d/PrefixStore.hpp>
 #include <c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/rpc/request_callback.h>
 #include <torch/csrc/distributed/rpc/rpc_agent.h>
@@ -60,6 +61,7 @@ struct RecvWork {
 class TORCH_API ProcessGroupAgent : public RpcAgent {
  public:
   ProcessGroupAgent(
+      const c10::intrusive_ptr<::c10d::Store>& store,
       std::string workerName,
       c10::intrusive_ptr<::c10d::ProcessGroup> pg,
       int numSendRecvThreads,
@@ -148,7 +150,6 @@ class TORCH_API ProcessGroupAgent : public RpcAgent {
     FutureInfo() = delete;
   };

-  void collectNames();
   // handle a SendWork request. This serializes the payload inside the work
   // object, and sends the message to the receiver using the underlying
   // ProcessGroup.
torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp
@@ -11,6 +11,7 @@ std::string fromVec(const std::vector<char>& vec) {
 }

 FaultyProcessGroupAgent::FaultyProcessGroupAgent(
+    const c10::intrusive_ptr<::c10d::Store>& store,
     std::string workerName,
     c10::intrusive_ptr<::c10d::ProcessGroup> pg,
     int numSendRecvThreads,
@@ -19,6 +20,7 @@ FaultyProcessGroupAgent::FaultyProcessGroupAgent(
     const std::unordered_map<std::string, float>& messageTypesToDelay,
     int failNumSends)
     : ProcessGroupAgent(
+          store,
           std::move(workerName),
           std::move(pg),
           numSendRecvThreads,
torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h
@@ -34,6 +34,7 @@ struct FaultyProcessGroupRpcBackendOptions
 class FaultyProcessGroupAgent : public ProcessGroupAgent {
  public:
   FaultyProcessGroupAgent(
+      const c10::intrusive_ptr<::c10d::Store>& store,
       std::string workerName,
       c10::intrusive_ptr<c10d::ProcessGroup> pg,
       int numSendRecvThreads,
torch/csrc/distributed/rpc/testing/init.cpp
@@ -67,6 +67,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
       module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
       .def(
           py::init<
+              const c10::intrusive_ptr<::c10d::Store>,
               std::string,
               c10::intrusive_ptr<::c10d::ProcessGroup>,
               int,
@@ -74,6 +75,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
               const std::vector<std::string>&,
               const std::unordered_map<std::string, float>&,
               int>(),
+          py::arg("store"),
           py::arg("name"),
           py::arg("process_group"),
           py::arg("num_send_recv_threads"),
torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
@@ -63,6 +63,7 @@ def _faulty_process_group_init_backend_handler(
     )

     return FaultyProcessGroupAgent(
+        store,
         name,
         group,
         rpc_backend_options.num_send_recv_threads,
torch/distributed/rpc/backend_registry.py
@@ -157,6 +157,7 @@ def _process_group_init_backend_handler(

     # TODO: add try-except and destroy _agent in all processes if any fails.
     return ProcessGroupAgent(
+        store,
         name,
         group,
         rpc_backend_options.num_send_recv_threads,
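None of this changes the user-facing API: applications still call init_rpc(), and the handler above now passes the rendezvous store into the agent, which performs the Store-based name collection and uniqueness check during construction. A hedged end-to-end sketch; the address and port are illustrative, and the PROCESS_GROUP backend used here was current at the time of this commit but has since been deprecated in favor of TensorPipe:

import os
import torch.distributed.rpc as rpc

# The rendezvous store is created internally from these settings.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"

rpc.init_rpc(
    "worker0",  # must be unique; now verified against the store
    backend=rpc.BackendType.PROCESS_GROUP,
    rank=0,
    world_size=1,
)
rpc.shutdown()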