Remove process_group_agent and faulty_process_group_agent files (#62985)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62985

Remove the process_group_agent and faulty_process_group_agent code now that the PROCESS_GROUP backend has been deprecated for RPC (https://github.com/pytorch/pytorch/issues/55615). Discussed with xush6528 that it was okay to remove ProcessGroupAgentTest and ProcessGroupAgentBench, which depended on process_group_agent.

Test Plan: CI tests

Reviewed By: pritamdamania87

Differential Revision: D30195576

fbshipit-source-id: 8b4381cffadb868b19d481198015d0a67b205811
commit 4d0497034c
parent 790553811c
committed by Facebook GitHub Bot
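
With the PROCESS_GROUP backend removed, RPC setups that still selected it have to use the TensorPipe backend, which remains the default. The snippet below is a minimal, illustrative sketch of that setup, not part of this commit; the worker name, world size, and init_method address are placeholders.

    import torch.distributed.rpc as rpc

    # Placeholder rendezvous address and a 30-second RPC timeout.
    options = rpc.TensorPipeRpcBackendOptions(
        init_method="tcp://localhost:29500",
        rpc_timeout=30.0,
    )
    rpc.init_rpc(
        "worker0",
        rank=0,
        world_size=2,
        backend=rpc.BackendType.TENSORPIPE,  # explicit here, but already the default
        rpc_backend_options=options,
    )
    # ... issue rpc.rpc_sync / rpc.remote calls as before ...
    rpc.shutdown()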
@@ -10,7 +10,6 @@ set(TORCH_RPC_TEST_DEPENDENCY_LIBS

if(USE_GLOO)
  list(APPEND TORCH_RPC_TEST_SOURCES
-   ${TORCH_RPC_TEST_DIR}/test_e2e_process_group.cpp
    ${TORCH_RPC_TEST_DIR}/test_e2e_tensorpipe.cpp
  )
endif()
@@ -1,47 +0,0 @@
#include <gtest/gtest.h>

#include "e2e_test_base.h"

#include <c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <torch/torch.h>

namespace torch {
namespace distributed {
namespace rpc {

class TestE2EProcessGroup : public TestE2EBase {
 protected:
  void buildRpcAgent() override {
    auto options = c10d::ProcessGroupGloo::Options::create();
    options->devices.push_back(
        ::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
    std::chrono::milliseconds rpcTimeout(30000);
    options->timeout = rpcTimeout;

    // Initialize server rpc agent.
    auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
        store, 0, numWorkers, options);

    rpcAgent = std::make_shared<ProcessGroupAgent>(
        store,
        "worker",
        pg,
        std::max(16U, std::thread::hardware_concurrency()),
        rpcTimeout,
        std::make_unique<RequestCallbackNoPython>());
  }
};

// End to end training loop test in C++ so that we can run LSAN on this test to
// catch memory leaks. Enabling LSAN with python multiprocessing has been
// challenging and we don't have a good solution yet.
TEST_F(TestE2EProcessGroup, TestTrainingLoop) {
  runTrainingLoop();
}

} // namespace rpc
} // namespace distributed
} // namespace torch
@@ -371,7 +371,6 @@ libtorch_distributed_extra_sources = [
    "torch/csrc/distributed/rpc/python_call.cpp",
    "torch/csrc/distributed/rpc/python_remote_call.cpp",
    "torch/csrc/distributed/rpc/python_resp.cpp",
-   "torch/csrc/distributed/rpc/process_group_agent.cpp",
    "torch/csrc/distributed/rpc/request_callback.cpp",
    "torch/csrc/distributed/rpc/request_callback_no_python.cpp",
    "torch/csrc/distributed/rpc/rpc_agent.cpp",
@@ -383,7 +382,6 @@ libtorch_distributed_extra_sources = [
    "torch/csrc/distributed/rpc/script_resp.cpp",
    "torch/csrc/distributed/rpc/tensorpipe_agent.cpp",
    "torch/csrc/distributed/rpc/tensorpipe_utils.cpp",
-   "torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp",
    "torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp",
    "torch/csrc/distributed/rpc/torchscript_functions.cpp",
    "torch/csrc/distributed/rpc/types.cpp",
@@ -1,863 +0,0 @@
|
||||
#include <torch/csrc/distributed/rpc/process_group_agent.h>
|
||||
|
||||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <fmt/format.h>
|
||||
#include <torch/csrc/distributed/rpc/agent_utils.h>
|
||||
#include <torch/csrc/distributed/rpc/utils.h>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
////////////////////////// MessageCounter /////////////////////////////////
|
||||
|
||||
ProcessGroupAgent::MessageCounter::MessageCounter(int worldSize)
|
||||
: counters_(worldSize) {}
|
||||
|
||||
void ProcessGroupAgent::MessageCounter::increment(int dst) {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
++counters_[dst];
|
||||
}
|
||||
|
||||
std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
return counters_;
|
||||
}
|
||||
|
||||
////////////////////////// MetricsTracker /////////////////////////////////
|
||||
|
||||
ProcessGroupAgent::AverageMetricsTracker::AverageMetricsTracker(
|
||||
std::string key,
|
||||
uint64_t currentSum,
|
||||
uint64_t currentCount)
|
||||
: key_(std::move(key)),
|
||||
currentSum_(currentSum),
|
||||
currentCount_(currentCount) {}
|
||||
|
||||
void ProcessGroupAgent::AverageMetricsTracker::addData(uint64_t dataPoint) {
|
||||
currentSum_ += dataPoint;
|
||||
++currentCount_;
|
||||
}
|
||||
|
||||
double ProcessGroupAgent::AverageMetricsTracker::computeAverage() {
|
||||
return currentCount_ == 0 ? 0 : currentSum_ / (double)currentCount_;
|
||||
}
|
||||
|
||||
//////////////////////// ProcessGroupAgent /////////////////////////////////
|
||||
|
||||
using steady_clock_time_point =
|
||||
std::chrono::time_point<std::chrono::steady_clock>;
|
||||
const steady_clock_time_point kInfiniteTimeoutTimePoint =
|
||||
std::chrono::time_point<std::chrono::steady_clock>::max();
|
||||
const std::string kNumPendingRequests = "agent.num_pending_requests";
|
||||
const std::string kThreadPoolSize = "agent.thread_pool_size";
|
||||
const std::string kNumIdleThreads = "agent.num_idle_threads";
|
||||
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
|
||||
const std::string kClientActiveCalls = "agent.client_active_calls";
|
||||
const std::string kServerActiveCalls = "agent.server_active_calls";
|
||||
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
|
||||
|
||||
ProcessGroupAgent::ProcessGroupAgent(
|
||||
const c10::intrusive_ptr<::c10d::Store>& store,
|
||||
std::string workerName,
|
||||
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
|
||||
int numSendRecvThreads,
|
||||
std::chrono::milliseconds rpcTimeout,
|
||||
std::unique_ptr<RequestCallback> cb)
|
||||
: RpcAgent(
|
||||
WorkerInfo(std::move(workerName), (int64_t)pg->getRank()),
|
||||
std::move(cb),
|
||||
rpcTimeout),
|
||||
pg_(std::move(pg)),
|
||||
sendCounts_(pg_->getSize()),
|
||||
recvCounts_(pg_->getSize()),
|
||||
nextId_(0),
|
||||
sendMutexes_(pg_->getSize()),
|
||||
threadPool_(numSendRecvThreads),
|
||||
timeoutThreadEnabled_{false} {
|
||||
// initialize metric info counters
|
||||
metrics_.resize(ProcessGroupAgentMetrics::N_METRICS);
|
||||
metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME] =
|
||||
std::make_unique<AverageMetricsTracker>(kGilAverageWaitTime);
|
||||
|
||||
nameMap_ = collectNames(
|
||||
::c10d::PrefixStore("names", store),
|
||||
workerInfo_.id_,
|
||||
workerInfo_.name_,
|
||||
pg_->getSize());
|
||||
auto workerRankIter = nameMap_.find(workerInfo_.name_);
|
||||
TORCH_CHECK(
|
||||
workerRankIter != nameMap_.end(),
|
||||
"Failed to resolve worker "
|
||||
"name ",
|
||||
workerInfo_.name_,
|
||||
" to a ProcessGroup rank.");
|
||||
TORCH_CHECK(
|
||||
pg_->getRank() == workerRankIter->second,
|
||||
"Resolved worker rank ",
|
||||
workerRankIter->second,
|
||||
" does not match ProcessGroup rank ",
|
||||
pg_->getRank());
|
||||
|
||||
// tmp vector to sort names in rank's order
|
||||
const auto worldSize = pg_->getSize();
|
||||
std::vector<std::string> tmpWorkerIds(worldSize);
|
||||
for (auto& entry : nameMap_) {
|
||||
tmpWorkerIds[entry.second] = entry.first;
|
||||
}
|
||||
|
||||
allWorkerInfo_.reserve(worldSize);
|
||||
for (worker_id_t rank = 0; rank < worldSize; ++rank) {
|
||||
allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
|
||||
}
|
||||
}
|
||||
|
||||
ProcessGroupAgent::~ProcessGroupAgent() {
|
||||
if (rpcAgentRunning_) {
|
||||
shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(
|
||||
const std::string& workerName) const {
|
||||
const auto idIter = nameMap_.find(workerName);
|
||||
TORCH_CHECK(
|
||||
idIter != nameMap_.end(), "Unknown destination worker ", workerName);
|
||||
|
||||
return allWorkerInfo_[idIter->second];
|
||||
}
|
||||
|
||||
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const {
|
||||
TORCH_CHECK(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
|
||||
id >= 0 && id < allWorkerInfo_.size(),
|
||||
"Invalid destination: ",
|
||||
id);
|
||||
return allWorkerInfo_[id];
|
||||
}
|
||||
|
||||
std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
|
||||
return allWorkerInfo_;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::join(bool /* unused */) {
|
||||
sync();
|
||||
std::unique_lock<std::mutex> lock(futureMutex_);
|
||||
futureCV_.wait(
|
||||
lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
|
||||
lock.unlock();
|
||||
pg_->barrier()->wait();
|
||||
}
|
||||
|
||||
bool ProcessGroupAgent::hasPendingMessage() {
|
||||
const auto worldSize = pg_->getSize();
|
||||
auto snapshot = std::make_unique<std::vector<int64_t>>();
|
||||
snapshot->reserve(2 * worldSize);
|
||||
auto recvSnapshot = recvCounts_.snapshot();
|
||||
auto sendSnapshot = sendCounts_.snapshot();
|
||||
snapshot->insert(
|
||||
snapshot->end(),
|
||||
std::make_move_iterator(recvSnapshot.begin()),
|
||||
std::make_move_iterator(recvSnapshot.end()));
|
||||
snapshot->insert(
|
||||
snapshot->end(),
|
||||
std::make_move_iterator(sendSnapshot.begin()),
|
||||
std::make_move_iterator(sendSnapshot.end()));
|
||||
|
||||
auto snapshotData = snapshot->data();
|
||||
auto deleteWhenDone = snapshot.release();
|
||||
std::vector<torch::Tensor> inputSnapshot = {torch::from_blob(
|
||||
snapshotData,
|
||||
{2, worldSize},
|
||||
[deleteWhenDone](void*) { delete deleteWhenDone; },
|
||||
{torch::kInt64})};
|
||||
// allgather both send and recv messages in one shot
|
||||
std::vector<std::vector<torch::Tensor>> outputSnapshots(1);
|
||||
|
||||
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
|
||||
for (const auto i : c10::irange(worldSize)) {
|
||||
outputSnapshots[0].emplace_back(
|
||||
torch::zeros({2, worldSize}, {torch::kInt64}));
|
||||
}
|
||||
|
||||
pg_->allgather(outputSnapshots, inputSnapshot)->wait();
|
||||
|
||||
// loop through all send/recv pairs to make sure that all sent messages are
|
||||
// processed.
|
||||
const auto& peerCounts = outputSnapshots[0];
|
||||
for (const auto from : c10::irange(worldSize)) {
|
||||
for (const auto to : c10::irange(worldSize)) {
|
||||
// peerCounts[x][0] is recv counts, and peerCounts[x][1] is send counts
|
||||
|
||||
const auto& sentCnt = peerCounts[from][1][to].data_ptr<int64_t>()[0];
|
||||
const auto& recvCnt = peerCounts[to][0][from].data_ptr<int64_t>()[0];
|
||||
// NB: we cannot throw an error when sentCnt < recvCnt here. Because, send
|
||||
// and recv counts on different workers are read in a distributed manner.
|
||||
// It is possible that the sender reads its send count before sending, but
|
||||
// the receive reads its recv count after receiving. Hence, both > and <
|
||||
// are valid states.
|
||||
if (sentCnt != recvCnt) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::sync() {
|
||||
// Block until all processes wants to sync.
|
||||
pg_->barrier()->wait();
|
||||
// block until all peers agree that all sent messages have been processed.
|
||||
do {
|
||||
// Finish all send/recv tasks in the thread pool
|
||||
threadPool_.waitWorkComplete();
|
||||
// As there could be nested RPC calls, or response callback could also
|
||||
// trigger more messages to be sent, we need to wait for the thread pool
|
||||
// again.
|
||||
} while (hasPendingMessage());
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::startImpl() {
|
||||
timeoutThreadEnabled_.store(true);
|
||||
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
|
||||
futureTimeoutThread_ =
|
||||
std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::shutdownImpl() {
|
||||
LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
|
||||
<< ".";
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(futureMutex_);
|
||||
timeoutThreadEnabled_.store(false);
|
||||
}
|
||||
futureTimeoutCV_.notify_one();
|
||||
futureTimeoutThread_.join();
|
||||
// Abort listener thread to stop accepting new work. We need to interrupt the
|
||||
// recvWork->wait() call the listener loop may be blocked in before joining
|
||||
// the thread.
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(recvWorkMutex_);
|
||||
if (recvWork_) {
|
||||
recvWork_->abort();
|
||||
}
|
||||
}
|
||||
listenerThread_.join();
|
||||
// Abort any pending sends to any destination rank that have not been
|
||||
// completed.
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(pendingSendMutex_);
|
||||
for (auto& it : currentPendingSends_) {
|
||||
const auto& pendingSends = it.second;
|
||||
const auto dst = it.first;
|
||||
for (const auto& send : pendingSends) {
|
||||
if (!send->isCompleted()) {
|
||||
LOG(INFO) << "Worker " << RpcAgent::getWorkerInfo().id_
|
||||
<< " aborting pending send to destination rank " << dst;
|
||||
|
||||
send->abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Note: calling threadPool_.waitWorkComplete() after listenerThread.join() so
|
||||
// that we can finish any possible work enqueued into the thread pool, before
|
||||
// python RPC handler is shutdown (see shutdown in rpc/api.py).
|
||||
threadPool_.waitWorkComplete();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<JitFuture> ProcessGroupAgent::send(
|
||||
const WorkerInfo& to,
|
||||
c10::intrusive_ptr<Message> message,
|
||||
const float rpcTimeoutSeconds,
|
||||
const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
|
||||
// Throw if we previously encountered an exception in ::listenLoop.
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);
|
||||
if (listenLoopException_) {
|
||||
std::rethrow_exception(listenLoopException_);
|
||||
}
|
||||
}
|
||||
|
||||
if (!rpcAgentRunning_.load()) {
|
||||
// We are trying to send but RPC has been shut down on this node. This can
|
||||
// happen if we are in a shutdown sequence but background threads are still
|
||||
// processing messages that result in send()s. Throw a descriptive error.
|
||||
auto err = c10::str(
|
||||
"Node ",
|
||||
RpcAgent::getWorkerInfo().id_,
|
||||
"tried to send() a message of type ",
|
||||
message->type(),
|
||||
" but RPC is no longer running on this node.");
|
||||
TORCH_CHECK(false, err);
|
||||
}
|
||||
TORCH_CHECK(
|
||||
to.id_ < (worker_id_t)pg_->getSize(),
|
||||
"Destination rank is out of bound, got ",
|
||||
to.id_,
|
||||
", but world size is ",
|
||||
pg_->getRank());
|
||||
|
||||
auto requestId = nextId();
|
||||
auto future = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
|
||||
if (message->isRequest()) {
|
||||
// millisecond level precision of when request started.
|
||||
auto futureStartTime = std::chrono::steady_clock::now();
|
||||
// if passed in timeout is unset, then use the currently set default timeout
|
||||
// for all RPCs.
|
||||
auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
|
||||
? getRpcTimeout()
|
||||
: std::chrono::milliseconds(
|
||||
static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
|
||||
|
||||
// Prepare endTime from timeout. Set infinite timeout if
|
||||
// specified.
|
||||
steady_clock_time_point endTime = timeout.count() == 0
|
||||
? kInfiniteTimeoutTimePoint
|
||||
: futureStartTime + timeout;
|
||||
bool notifyThread = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{futureMutex_};
|
||||
// Insert future into future map.
|
||||
futures_.emplace(
|
||||
std::piecewise_construct,
|
||||
std::forward_as_tuple(requestId),
|
||||
std::forward_as_tuple(FutureInfo(future, endTime, to.id_, timeout)));
|
||||
// insert future into timeouts map to keep track of its timeout
|
||||
auto& requestIds = futureTimeouts_[endTime];
|
||||
requestIds.insert(requestId);
|
||||
// Signal the watchdog to monitor future timeouts if this is the first
|
||||
// future created or it has earlier end time than other futures in the
|
||||
// map.
|
||||
if (futureTimeouts_.begin()->first == endTime &&
|
||||
(requestIds.size() == 1)) {
|
||||
notifyThread = true;
|
||||
}
|
||||
}
|
||||
if (notifyThread) {
|
||||
// Notify the watchdog thread only after releasing the lock,
|
||||
// so watchdog can acquire lock on waking up.
|
||||
futureTimeoutCV_.notify_one();
|
||||
}
|
||||
message->setId(requestId);
|
||||
++clientActiveCalls_;
|
||||
} else {
|
||||
future->markCompleted(IValue());
|
||||
}
|
||||
|
||||
// Sending to ourselves: bypass the send logic and enqueue directly
|
||||
// to our receiving queue.
|
||||
if (to.id_ == (worker_id_t)pg_->getRank()) {
|
||||
sendToSelf(std::move(message));
|
||||
return future;
|
||||
}
|
||||
|
||||
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
|
||||
// longer be alive when the ``SendWork`` is executed. For example, the
|
||||
// application could query the ``WorkerInfo`` using name through the
|
||||
// ``RpcAgent::getWorkerInfo`` API, and pass the ``WorkerInfo`` back here, so
|
||||
// we have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerInfo``
|
||||
// reference on Python side could die before ``SendWork`` uses it, and Pybind
|
||||
// will not keep the Python reference alive even if it originally comes from
|
||||
// the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
|
||||
// C++ land.
|
||||
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::handleSend(const SendWork& work) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pessimizing-move)
|
||||
auto serializedPayload = std::make_unique<std::string>(std::move(
|
||||
wireSerialize(work.message_->payload(), work.message_->tensors())));
|
||||
|
||||
std::vector<torch::Tensor> preamble = {torch::tensor(
|
||||
{(int64_t)pg_->getRank(),
|
||||
(int64_t)serializedPayload->length(),
|
||||
(int64_t)work.message_->type(),
|
||||
(int64_t)work.message_->id()},
|
||||
{torch::kInt64})};
|
||||
|
||||
// ProcessGroup is not thread-safe when sending with the same tag,
|
||||
// hence the lock
|
||||
std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> pendingSends;
|
||||
const auto dst = work.to_.id_;
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
auto serializedPayloadData = const_cast<char*>(serializedPayload->data());
|
||||
auto serializedPayloadSize = serializedPayload->size();
|
||||
std::string* deleteWhenDone = serializedPayload.release();
|
||||
std::vector<torch::Tensor> payload = {torch::from_blob(
|
||||
reinterpret_cast<void*>(serializedPayloadData),
|
||||
serializedPayloadSize,
|
||||
[deleteWhenDone](void*) { delete deleteWhenDone; },
|
||||
{torch::kChar})};
|
||||
pendingSends.reserve(2);
|
||||
|
||||
sendCounts_.increment(dst);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(sendMutexes_[dst]);
|
||||
pendingSends.emplace_back(pg_->send(preamble, dst, dst /* channelTag */));
|
||||
pendingSends.emplace_back(pg_->send(payload, dst, dst /* channelTag */));
|
||||
}
|
||||
// Write pendingSends to a global map so that they can be interrupted by
|
||||
// ::shutdown().
|
||||
{
|
||||
std::lock_guard<std::mutex> pendingSendGuard(pendingSendMutex_);
|
||||
for (auto& p : pendingSends) {
|
||||
currentPendingSends_[dst].insert(p);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& pendingSend : pendingSends) {
|
||||
if (!rpcAgentRunning_.load() || !pendingSend->wait()) {
|
||||
// Send was interrupted or RPC is not running.
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Erase the pending sends that we added since we have returned from wait.
|
||||
{
|
||||
std::lock_guard<std::mutex> pendingSendGuard(pendingSendMutex_);
|
||||
// NB: We cannot just erase all of currentPendingSends[dst], since this
|
||||
// might preemptively remove sends from other threads.
|
||||
auto& set = currentPendingSends_[dst];
|
||||
for (auto& p : pendingSends) {
|
||||
set.erase(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::sendToSelf(c10::intrusive_ptr<Message> message) {
|
||||
// NOLINTNEXTLINE(modernize-avoid-bind)
|
||||
threadPool_.run(std::bind(
|
||||
[this](c10::intrusive_ptr<Message> message) {
|
||||
// Unlike the other cases, need to add a tensor deleter, since the
|
||||
// data outlives the scope of this function. It's shared_ptr<> due
|
||||
// to c++11 lambda capture limitations with unique_ptr<>.
|
||||
std::unique_ptr<std::string> payload;
|
||||
try {
|
||||
payload = std::make_unique<std::string>(
|
||||
wireSerialize(message->payload(), message->tensors()));
|
||||
// only increment sendCounts when the message is indeed added into
|
||||
// local recv.
|
||||
sendCounts_.increment(pg_->getRank());
|
||||
} catch (std::exception& e) {
|
||||
markFutureWithError(message->id(), e.what());
|
||||
return;
|
||||
}
|
||||
const char* data = payload->data();
|
||||
size_t len = payload->length();
|
||||
std::string* delete_when_done = payload.release();
|
||||
enqueueRecv(RecvWork(
|
||||
getWorkerInfo(pg_->getRank()),
|
||||
message->type(),
|
||||
message->id(),
|
||||
torch::from_blob(
|
||||
(void*)data,
|
||||
len,
|
||||
[delete_when_done](void*) { delete delete_when_done; },
|
||||
{torch::kChar})));
|
||||
},
|
||||
std::move(message)));
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::enqueueSend(SendWork work) {
|
||||
// NB: this can be changed to use a native move capture when moved to C++14
|
||||
// NOLINTNEXTLINE(modernize-avoid-bind)
|
||||
threadPool_.run(std::bind(
|
||||
[this](const SendWork& work) {
|
||||
try {
|
||||
handleSend(work);
|
||||
} catch (std::exception& e) {
|
||||
auto errorStr = c10::str(
|
||||
"Encountered exception in ProcessGroupAgent::enqueueSend: ",
|
||||
e.what(),
|
||||
" on node: ",
|
||||
RpcAgent::getWorkerInfo().id_);
|
||||
auto exceptionMsg =
|
||||
rpc::createExceptionResponse(errorStr, work.message_->id());
|
||||
if (work.message_->isRequest()) {
|
||||
// Mark the future with corresponding to this request with an error.
|
||||
markFutureWithError(*exceptionMsg);
|
||||
} else if (work.message_->isResponse()) {
|
||||
// Try sending the error along.
|
||||
handleSend(SendWork(work.to_, std::move(exceptionMsg)));
|
||||
}
|
||||
}
|
||||
},
|
||||
std::move(work)));
|
||||
}
|
||||
|
||||
bool ProcessGroupAgent::handleRecv(RecvWork& work) {
|
||||
torch::Tensor& payload = work.payload_;
|
||||
auto data = wireDeserialize(payload.storage().data(), payload.numel());
|
||||
auto message = c10::make_intrusive<Message>(
|
||||
std::move(data.first), std::move(data.second), work.type_, work.id_);
|
||||
if (message->isRequest()) {
|
||||
++serverActiveCalls_;
|
||||
c10::intrusive_ptr<JitFuture> futureResponse;
|
||||
try {
|
||||
futureResponse = cb_->operator()(*message, {});
|
||||
} catch (const std::exception& e) {
|
||||
futureResponse = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
|
||||
futureResponse->setError(std::current_exception());
|
||||
}
|
||||
if (futureResponse->completed()) {
|
||||
--serverActiveCalls_;
|
||||
if (!futureResponse->hasError()) {
|
||||
send(work.from_, futureResponse->value().toCustomClass<Message>());
|
||||
} else {
|
||||
send(
|
||||
work.from_,
|
||||
createExceptionResponse(
|
||||
futureResponse->tryRetrieveErrorMessage(), message->id()));
|
||||
}
|
||||
} else {
|
||||
++serverActiveAsyncCalls_;
|
||||
// Callback processing returned an incomplete future. Add sending the
|
||||
// response as a callback which fires when the future completes.
|
||||
auto fromId = work.from_.id_;
|
||||
auto requestId = work.id_;
|
||||
futureResponse->addCallback(
|
||||
[this, fromId, requestId](JitFuture& futureResponse) {
|
||||
--serverActiveCalls_;
|
||||
--serverActiveAsyncCalls_;
|
||||
if (!futureResponse.hasError()) {
|
||||
send(
|
||||
getWorkerInfo(fromId),
|
||||
futureResponse.value().toCustomClass<Message>());
|
||||
} else {
|
||||
send(
|
||||
getWorkerInfo(fromId),
|
||||
createExceptionResponse(
|
||||
futureResponse.tryRetrieveErrorMessage(), requestId));
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (message->isResponse()) {
|
||||
auto id = message->id();
|
||||
c10::intrusive_ptr<JitFuture> jitFuture;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{futureMutex_};
|
||||
const auto& futureInfo = futures_.find(id);
|
||||
if (futureInfo == futures_.end()) {
|
||||
// Received a completion for an already-processed future (such as one
|
||||
// that timed out), drop the recv. By returning false, recvCounts will
|
||||
// not be incremented, it will be incremented by the thread that
|
||||
// determined that the future timed out.
|
||||
return false;
|
||||
}
|
||||
// Use futureInfo before destructing it.
|
||||
jitFuture = futureInfo->second.future_;
|
||||
auto endTime = futureInfo->second.endTime_;
|
||||
futures_.erase(id);
|
||||
// look up the corresponding future by its time out and request
|
||||
// ID, and remove it from the timeouts map
|
||||
auto& futuresAtTime = futureTimeouts_[endTime];
|
||||
auto it = futuresAtTime.find(id);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
it != futuresAtTime.end(),
|
||||
"Error: could not find future in futureTimeouts map, race condition.");
|
||||
futuresAtTime.erase(it);
|
||||
if (futuresAtTime.empty()) {
|
||||
// remove the key from futureTimeouts_
|
||||
futureTimeouts_.erase(endTime);
|
||||
}
|
||||
}
|
||||
futureCV_.notify_all();
|
||||
--clientActiveCalls_;
|
||||
if (message->type() == MessageType::EXCEPTION) {
|
||||
jitFuture->setError(std::make_exception_ptr(std::runtime_error(
|
||||
std::string(message->payload().begin(), message->payload().end()))));
|
||||
} else {
|
||||
jitFuture->markCompleted(std::move(message));
|
||||
}
|
||||
} else {
|
||||
// TODO: pass the error back to the caller instead of crashing here.
|
||||
TORCH_INTERNAL_ASSERT(false, "unrecognized message type ", message->type());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::enqueueRecv(RecvWork work) {
|
||||
// NOLINTNEXTLINE(modernize-avoid-bind)
|
||||
threadPool_.run(std::bind(
|
||||
[&](RecvWork& work) {
|
||||
try {
|
||||
// Only increment recvCounts if handleRecv() tells us to. We may not,
|
||||
// i.e. if we process work corresponding to a future that has already
|
||||
// been processed.
|
||||
if (handleRecv(work)) {
|
||||
recvCounts_.increment(work.from_.id_);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
// Processing for this request/response failed. Log the details of the
|
||||
// request.
|
||||
auto fromId = work.from_.id_;
|
||||
auto err = c10::str(
|
||||
"Internal error while processing request of type ",
|
||||
work.type_,
|
||||
" on node ",
|
||||
RpcAgent::getWorkerInfo().id_,
|
||||
", from node ",
|
||||
fromId,
|
||||
" : ",
|
||||
e.what());
|
||||
LOG(INFO) << err;
|
||||
// Still increment so that this recv is recognized as non-oustanding
|
||||
// during graceful shutdown.
|
||||
recvCounts_.increment(work.from_.id_);
|
||||
}
|
||||
},
|
||||
std::move(work)));
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::markFutureWithError(Message& message) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
message.type() == MessageType::EXCEPTION,
|
||||
"markFutureWithError should be only called with Message that has type Exception.");
|
||||
markFutureWithError(
|
||||
message.id(),
|
||||
std::string(message.payload().begin(), message.payload().end()));
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) {
|
||||
c10::intrusive_ptr<JitFuture> jitFuture;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{futureMutex_};
|
||||
const auto& futureInfo = futures_.find(id);
|
||||
|
||||
if (futureInfo == futures_.end()) {
|
||||
// Did not find future in map - this can occur when the future has timed
|
||||
// out and been processed accordingly.
|
||||
return;
|
||||
}
|
||||
jitFuture = futureInfo->second.future_;
|
||||
auto rpcEndTime = futureInfo->second.endTime_;
|
||||
futures_.erase(id);
|
||||
// look up the corresponding future by its time out and request ID,
|
||||
// and remove it from the timeouts map
|
||||
auto& futuresAtTime = futureTimeouts_[rpcEndTime];
|
||||
auto it = futuresAtTime.find(id);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
it != futuresAtTime.end(),
|
||||
"Error: could not find future in futureTimeouts map, race condition.");
|
||||
futuresAtTime.erase(it);
|
||||
if (futuresAtTime.empty()) {
|
||||
// remove the key from futureTimeouts_
|
||||
futureTimeouts_.erase(rpcEndTime);
|
||||
}
|
||||
}
|
||||
|
||||
--clientActiveCalls_;
|
||||
jitFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg)));
|
||||
futureCV_.notify_all();
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::listenLoop() {
|
||||
try {
|
||||
listenLoopInternal();
|
||||
} catch (const std::exception& e) {
|
||||
// Error occured in listenLoop(). Stop receiving thread and store
|
||||
// exception to indicate that the RPC agent is in an unhealthy state and
|
||||
// we should shutdown.
|
||||
auto err = c10::str(
|
||||
"Encountered exception in ProcessGroupAgent::listenLoop(): ",
|
||||
e.what(),
|
||||
" on worker ",
|
||||
RpcAgent::getWorkerInfo().id_,
|
||||
". This means that the RPC agent is in an unhealthy state and unusable.");
|
||||
LOG(ERROR) << err;
|
||||
{
|
||||
// Lock write to listenLoopException_ since ::send() reads from it.
|
||||
std::lock_guard<std::mutex> guard(listenLoopExceptionMutex_);
|
||||
listenLoopException_ = std::current_exception();
|
||||
}
|
||||
} catch (...) {
|
||||
std::string unknownErrorMsg =
|
||||
"Unknown exception occured in "
|
||||
"ProcessGroupAgent::listenLoop. RPC Agent is in an unhealthy state and "
|
||||
"unusable.";
|
||||
LOG(ERROR) << unknownErrorMsg;
|
||||
{
|
||||
// Lock write to listenLoopException_ since ::send() reads from it.
|
||||
std::lock_guard<std::mutex> guard(listenLoopExceptionMutex_);
|
||||
listenLoopException_ =
|
||||
std::make_exception_ptr(std::runtime_error(unknownErrorMsg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::listenLoopInternal() {
|
||||
while (rpcAgentRunning_.load()) {
|
||||
// rank, tensor size, message type
|
||||
std::vector<torch::Tensor> preamble = {torch::empty({4}, {torch::kInt64})};
|
||||
auto work = pg_->recvAnysource(preamble, pg_->getRank());
|
||||
{
|
||||
// Write class variable so it can be aborted by shutdown()
|
||||
std::lock_guard<std::mutex> guard(recvWorkMutex_);
|
||||
recvWork_ = work;
|
||||
}
|
||||
|
||||
if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
|
||||
|
||||
auto srcRank = preamble_items[0];
|
||||
auto size = preamble_items[1];
|
||||
MessageType type = MessageType(preamble_items[2]);
|
||||
int64_t id = preamble_items[3];
|
||||
|
||||
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
|
||||
work = pg_->recv(tensors, srcRank, pg_->getRank());
|
||||
{
|
||||
// Write class variable so it can be aborted by shutdown()
|
||||
std::lock_guard<std::mutex> guard(recvWorkMutex_);
|
||||
recvWork_ = work;
|
||||
}
|
||||
|
||||
if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
|
||||
return;
|
||||
}
|
||||
|
||||
enqueueRecv(
|
||||
RecvWork(allWorkerInfo_[srcRank], type, id, std::move(tensors[0])));
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::pollTimedOutRPCs() {
|
||||
while (timeoutThreadEnabled_.load()) {
|
||||
std::unique_lock<std::mutex> lock{futureMutex_};
|
||||
steady_clock_time_point minEndTime;
|
||||
// Estimate amount of time the first future will time out in, and sleep
|
||||
// for that long.
|
||||
// if there are no futures or the first future's RPC timeout is set to 0
|
||||
// (meaning no timeout), then sleep for a set "infinity" time.
|
||||
if (futureTimeouts_.empty()) {
|
||||
minEndTime = kInfiniteTimeoutTimePoint;
|
||||
} else {
|
||||
minEndTime = futureTimeouts_.begin()->first;
|
||||
}
|
||||
|
||||
auto shouldUpdateMinEndTimePredicate = [&, this]() -> bool {
|
||||
// Notice, whoever modifies `timeoutThreadEnabled_`
|
||||
// must acquire a lock on `futureMutex_`.
|
||||
// Otherwise, this predicate could deadlock.
|
||||
// If during evaluating the predicate, `::shutdown()` is called, then
|
||||
// the predicate missed the notification before it started waiting
|
||||
// on the cond var.
|
||||
if (!timeoutThreadEnabled_.load()) {
|
||||
return true;
|
||||
}
|
||||
steady_clock_time_point minEndTimeInMap = kInfiniteTimeoutTimePoint;
|
||||
if (futureTimeouts_.empty()) {
|
||||
minEndTimeInMap = kInfiniteTimeoutTimePoint;
|
||||
} else {
|
||||
minEndTimeInMap = futureTimeouts_.begin()->first;
|
||||
}
|
||||
return minEndTimeInMap < minEndTime;
|
||||
};
|
||||
|
||||
bool shouldUpdateMinEndTime = true;
|
||||
if (minEndTime == kInfiniteTimeoutTimePoint) {
|
||||
futureTimeoutCV_.wait(lock, shouldUpdateMinEndTimePredicate);
|
||||
} else {
|
||||
shouldUpdateMinEndTime = futureTimeoutCV_.wait_until(
|
||||
lock, minEndTime, shouldUpdateMinEndTimePredicate);
|
||||
}
|
||||
if (shouldUpdateMinEndTime) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto timedOutFutures = processTimedOutFutures();
|
||||
lock.unlock();
|
||||
futureCV_.notify_all();
|
||||
|
||||
for (const auto& timedOutFuture : timedOutFutures) {
|
||||
auto errStr =
|
||||
fmt::format(kRpcTimeoutErrorStr, timedOutFuture.timeout_.count());
|
||||
auto err = makeRPCError(errStr, RPCErrorType::TIMEOUT);
|
||||
|
||||
if (!timedOutFuture.future_->hasError()) {
|
||||
--clientActiveCalls_;
|
||||
timedOutFuture.future_->setError(
|
||||
std::make_exception_ptr(std::runtime_error(err)));
|
||||
// The future timed out and will not be processed by handleRecv(), even
|
||||
// if we eventually get a response. In order to keep track of all
|
||||
// send/recv pairs, we increment the count here.
|
||||
const int dst = timedOutFuture.dstRank_;
|
||||
recvCounts_.increment(dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<ProcessGroupAgent::FutureInfo> ProcessGroupAgent::
|
||||
processTimedOutFutures() {
|
||||
std::vector<FutureInfo> timedOutFutures;
|
||||
for (auto it = futureTimeouts_.begin(); it != futureTimeouts_.end();
|
||||
/* intentional no increment */) {
|
||||
const auto& endTime = it->first;
|
||||
if (std::chrono::steady_clock::now() < endTime) {
|
||||
// Since the futureTimeouts_ map is ordered by timeout, we don't need
|
||||
// to check the remaining futures.
|
||||
break;
|
||||
} else {
|
||||
const auto& futureIDs = it->second;
|
||||
for (const auto& futureID : futureIDs) {
|
||||
auto futureIt = futures_.find(futureID);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
futureIt != futures_.end(),
|
||||
"Race Condition - Expected future does not exist in map");
|
||||
const auto futInfo = futureIt->second;
|
||||
timedOutFutures.push_back(futInfo);
|
||||
futures_.erase(futureID);
|
||||
}
|
||||
it = futureTimeouts_.erase(it);
|
||||
}
|
||||
}
|
||||
return timedOutFutures;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> ProcessGroupAgent::getMetrics() {
|
||||
std::unordered_map<std::string, std::string> metrics;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(futureMutex_);
|
||||
auto futuresSize = futures_.size();
|
||||
lock.unlock();
|
||||
metrics[kNumPendingRequests] = c10::to_string(futuresSize);
|
||||
}
|
||||
metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
|
||||
metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
|
||||
metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_.load());
|
||||
metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_.load());
|
||||
metrics[kServerActiveAsyncCalls] =
|
||||
c10::to_string(serverActiveAsyncCalls_.load());
|
||||
if (isGILProfilingEnabled()) {
|
||||
// Add time-series based metrics, just GIL wait times for now.
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(metricsMutex_);
|
||||
auto avgGilWaitTime = metrics_[GIL_WAIT_TIME]->computeAverage();
|
||||
lock.unlock();
|
||||
metrics[kGilAverageWaitTime] = c10::to_string(avgGilWaitTime);
|
||||
}
|
||||
}
|
||||
return metrics;
|
||||
}
|
||||
|
||||
void ProcessGroupAgent::addGilWaitTime(
|
||||
const std::chrono::microseconds gilWaitTime) {
|
||||
std::lock_guard<std::mutex> lock(metricsMutex_);
|
||||
metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME]->addData(
|
||||
gilWaitTime.count());
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
@@ -1,290 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/thread_pool.h>
|
||||
#include <c10d/PrefixStore.hpp>
|
||||
#include <c10d/ProcessGroup.hpp>
|
||||
#include <torch/csrc/distributed/rpc/request_callback.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_agent.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
|
||||
namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
constexpr auto kDefaultNumSendRecvThreads = 4;
|
||||
|
||||
struct TORCH_API ProcessGroupRpcBackendOptions : public RpcBackendOptions {
|
||||
ProcessGroupRpcBackendOptions(
|
||||
int num_send_recv_threads,
|
||||
float rpc_timeout,
|
||||
std::string init_method)
|
||||
: RpcBackendOptions(rpc_timeout, init_method),
|
||||
numSendRecvThreads(num_send_recv_threads) {
|
||||
TORCH_CHECK(
|
||||
num_send_recv_threads > 0,
|
||||
"Cannot create ProcessGroup RPC backend with ",
|
||||
num_send_recv_threads,
|
||||
" threads in the thread-pool.");
|
||||
}
|
||||
|
||||
int numSendRecvThreads;
|
||||
};
|
||||
|
||||
// SendWork and RecvWork will be put into a task queue, and later picked up by
|
||||
// worker threads from the same ThreadPool.
|
||||
struct TORCH_API SendWork {
|
||||
SendWork(const WorkerInfo& to, c10::intrusive_ptr<Message> message)
|
||||
: to_(to), message_(std::move(message)) {}
|
||||
|
||||
const WorkerInfo& to_;
|
||||
c10::intrusive_ptr<Message> message_;
|
||||
};
|
||||
|
||||
// SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
|
||||
// to allow us to run serialization/deserialization in the worker threads.
|
||||
struct TORCH_API RecvWork {
|
||||
RecvWork(
|
||||
const WorkerInfo& from,
|
||||
MessageType type,
|
||||
int64_t id,
|
||||
torch::Tensor&& payload)
|
||||
: from_(from), type_(type), id_(id), payload_(payload) {}
|
||||
|
||||
const WorkerInfo& from_;
|
||||
const MessageType type_;
|
||||
const int64_t id_;
|
||||
torch::Tensor payload_;
|
||||
};
|
||||
|
||||
class TORCH_API ProcessGroupAgent : public RpcAgent {
|
||||
public:
|
||||
ProcessGroupAgent(
|
||||
const c10::intrusive_ptr<::c10d::Store>& store,
|
||||
std::string workerName,
|
||||
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
|
||||
int numSendRecvThreads,
|
||||
std::chrono::milliseconds rpcTimeout,
|
||||
std::unique_ptr<RequestCallback> cb);
|
||||
|
||||
const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
|
||||
|
||||
const WorkerInfo& getWorkerInfo(worker_id_t id) const override;
|
||||
|
||||
std::vector<WorkerInfo> getWorkerInfos() const override;
|
||||
|
||||
void join(bool shutdown = false) override;
|
||||
|
||||
void sync() override;
|
||||
|
||||
void startImpl() override;
|
||||
|
||||
void shutdownImpl() override;
|
||||
|
||||
~ProcessGroupAgent() override;
|
||||
|
||||
std::unordered_map<std::string, std::string> getMetrics() override;
|
||||
|
||||
protected:
|
||||
// This method wraps the destination information and the message into a
|
||||
// SendWork object, and put the SendWork into a queue. Another thread will
|
||||
// consume SendWork from the queue and send it out.
|
||||
c10::intrusive_ptr<JitFuture> send(
|
||||
const WorkerInfo& to,
|
||||
c10::intrusive_ptr<Message> message,
|
||||
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
|
||||
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
|
||||
override;
|
||||
|
||||
// put SendWork into a queue and notify the worker thread
|
||||
virtual void enqueueSend(SendWork work);
|
||||
// Bypass handleSend() logic and send a message to self rank
|
||||
virtual void sendToSelf(c10::intrusive_ptr<Message> message);
|
||||
|
||||
private:
|
||||
class MessageCounter {
|
||||
public:
|
||||
explicit MessageCounter(int worldSize);
|
||||
void increment(int dst);
|
||||
std::vector<int64_t> snapshot();
|
||||
|
||||
private:
|
||||
std::vector<int64_t> counters_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
// TODO: this class should inherit from a MetricsTracker, and can be extended
|
||||
// to track num_sends, recvs, average size of messages, etc.
|
||||
struct AverageMetricsTracker {
|
||||
std::string key_;
|
||||
uint64_t currentSum_;
|
||||
uint64_t currentCount_;
|
||||
|
||||
explicit AverageMetricsTracker(
|
||||
std::string key,
|
||||
uint64_t currentSum = 0,
|
||||
uint64_t currentCount = 0);
|
||||
|
||||
void addData(uint64_t dataPoint);
|
||||
double computeAverage();
|
||||
};
|
||||
|
||||
// The FutureInfo struct stores a shared_ptr to the future, as well as
|
||||
// additional information to manage timeouts and destination information,
|
||||
// which is needed for termination detection.
|
||||
struct FutureInfo {
|
||||
c10::intrusive_ptr<JitFuture> future_;
|
||||
steady_clock_time_point endTime_;
|
||||
int dstRank_;
|
||||
std::chrono::milliseconds timeout_;
|
||||
FutureInfo(
|
||||
c10::intrusive_ptr<JitFuture> future,
|
||||
const steady_clock_time_point& endTime,
|
||||
int dstRank,
|
||||
const std::chrono::milliseconds timeout)
|
||||
: future_(std::move(future)),
|
||||
endTime_(endTime),
|
||||
dstRank_(dstRank),
|
||||
timeout_(timeout) {}
|
||||
FutureInfo() = delete;
|
||||
};
|
||||
|
||||
// handle a SendWork request. This serializes the payload inside the work
|
||||
// object, and sends the message to the receiver using the underlying
|
||||
// ProcessGroup.
|
||||
void handleSend(const SendWork& work);
|
||||
// put RecvWork into a queue and notify the worker thread
|
||||
void enqueueRecv(RecvWork work);
|
||||
// handle a RecvWork request. Return true if we should increment recvCounts,
|
||||
// false if not (i.e. if the RPC timed out and we are getting a result after
|
||||
// the timeout). This ensures that the messages accounted for in
|
||||
// hasPendingMessage() are tallied properly during a graceful shutdown.
|
||||
bool handleRecv(RecvWork& work);
|
||||
// Loop that receives and processes messages
|
||||
void listenLoopInternal();
|
||||
// Calls listenLoopInternal and handles errors such as timeouts on the
|
||||
// process group.
|
||||
void listenLoop();
|
||||
// exception_pointer correspnding to an exception raised in listenLoop (if
|
||||
// there is one), and lock to guard access.
|
||||
std::exception_ptr listenLoopException_;
|
||||
std::mutex listenLoopExceptionMutex_;
|
||||
// poll for timed out RPCs
|
||||
void pollTimedOutRPCs();
|
||||
// process timed out futures
|
||||
const std::vector<FutureInfo> processTimedOutFutures();
|
||||
// compute the remaining time for an RPC, given its end time.
|
||||
const std::chrono::milliseconds getRPCRemainingTime(
|
||||
const std::chrono::milliseconds& rpcEndTime) const;
|
||||
|
||||
// a helper function to mark a future in the futures_ map with a message. The
|
||||
// future is marked with the passed in message, and then removed from the
|
||||
// futures_ map. It is also removed from the futureTimeouts_ map since these
|
||||
// maps are kept in sync.
|
||||
void markFutureWithError(Message& message);
|
||||
void markFutureWithError(int64_t id, std::string errorMsg);
|
||||
|
||||
// Note [Termination Detection]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
//
|
||||
// RpcAgent implementations must properly detect termination. Otherwise, it
|
||||
// would result in message loss, RRef leak, or process hang. It is not
|
||||
// sufficient to just wait for the thread pool to finish processing all tasks
|
||||
// after all processes hit the join function. There could be nested rpc/remote
|
||||
// calls, meaning that an empty task queue in the thread pool does not mean
|
||||
// there will be no tasks added in the future. Moreover, in the listenLoop,
|
||||
// there is a period of time when the message has been received but not yet
|
||||
// inserted into the thread pool, which also suggests that the empty task
|
||||
// queue is not a good indicator for termination.
|
||||
//
|
||||
// To detect termination, each ProcessGroupAgent maintains a sent message
|
||||
// counter and a received message counter. The sent message counter is
|
||||
// incremented whenever a message is sent, and the receive message counter is
|
||||
// only incremented when a message has been processed. During termination, all
|
||||
// ProcessGroupAgent instances run an allgather to collect counters from all
|
||||
// peers, which means that all agents will have a consistent view on the
|
||||
// message count snapshot. They would only terminate if all sent/received
|
||||
// message counters match.
|
||||
bool hasPendingMessage();
|
||||
|
||||
int64_t nextId() {
|
||||
return ++nextId_;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
|
||||
// worker name -> rank
|
||||
std::unordered_map<std::string, worker_id_t> nameMap_;
|
||||
std::vector<WorkerInfo> allWorkerInfo_;
|
||||
// record the number of messages sent to and received from each peer. The recv
|
||||
// counter is only marked after the message is processed. Join uses allgather
|
||||
// to collect all counts from all peers, uses these counters to detect global
|
||||
// termination and only exit when all sent messages are processed.
|
||||
MessageCounter sendCounts_;
|
||||
MessageCounter recvCounts_;
|
||||
|
||||
std::atomic<int64_t> nextId_;
|
||||
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
|
||||
// when using the same tag.
|
||||
std::vector<std::mutex> sendMutexes_;
|
||||
std::thread listenerThread_;
|
||||
// A thread to poll existing futures and check for timed out ones.
|
||||
std::thread futureTimeoutThread_;
|
||||
// Lock and shared ptr to currently pending work, set in listenloop() and
|
||||
// interruptible in shutdown().
|
||||
std::mutex recvWorkMutex_;
|
||||
c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;
|
||||
// Map of dst rank to current oustanding sends that we are waiting on. In the
|
||||
// case of a call to ::shutdown() while we are still waiting on these sends,
|
||||
// the pending sends contained in this map will be aborted, allowing the
|
||||
// waiting thread to be unblocked.
|
||||
std::unordered_map<
|
||||
worker_id_t,
|
||||
std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
|
||||
currentPendingSends_;
|
||||
// Lock to serialize access to the above map.
|
||||
std::mutex pendingSendMutex_;
|
||||
// A threadPool that processing both SendWork and RecvWork. There are two
|
||||
// motivations for adding a ThreadPool:
|
||||
// (1) RPC serialization/deserialization and processing can be expensive,
|
||||
// hence using multiple threads to speed it up.
|
||||
// (2) The current RPC API does not support asynchronous UDFs, e.g., UDFs can
|
||||
// not yield in the middle of execution to wait for IO, and resume the IO
|
||||
// is done. This would result in deadlocks when we have nested RPC calls.
|
||||
// NB: Ideally, this should be addressed by supporting asynchronous UDF.
|
||||
// This is just a temporary solution for (2).
|
||||
ThreadPool threadPool_;
|
||||
// Atomic to indicate whether the timeout thread is enabled.
|
||||
std::atomic<bool> timeoutThreadEnabled_;
|
||||
// Mapping of request id to FutureInfo struct.
|
||||
std::unordered_map<int64_t, FutureInfo> futures_;
|
||||
// A map to keep track of when futures time out. The map is keyed by the time
|
||||
// (millisecond level precision) the future will expire. This is so that timed
|
||||
// out futures can be efficiently cleaned up, and we can quickly exit if we
|
||||
// find a future that has not timed out. The values correspond to an
|
||||
// unordered_set of future ids that started at that time. This map must be
|
||||
// kept in sync with the above futures_ map.
|
||||
std::map<steady_clock_time_point, std::unordered_set<int64_t>>
|
||||
futureTimeouts_;
|
||||
mutable std::mutex futureMutex_;
|
||||
mutable std::condition_variable futureCV_;
|
||||
// CV to wake up watchdog thread that watches for timed out futures.
|
||||
std::condition_variable futureTimeoutCV_;
|
||||
// Metrics tracked for ProcessGroupAgent.
|
||||
enum ProcessGroupAgentMetrics {
|
||||
GIL_WAIT_TIME = 0,
|
||||
|
||||
N_METRICS,
|
||||
};
|
||||
std::mutex metricsMutex_;
|
||||
std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
|
||||
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
|
||||
|
||||
std::atomic<int32_t> clientActiveCalls_{0};
|
||||
std::atomic<int32_t> serverActiveCalls_{0};
|
||||
std::atomic<int32_t> serverActiveAsyncCalls_{0};
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
} // namespace torch
|
@@ -1,152 +0,0 @@
#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

std::string fromVec(const std::vector<char>& vec) {
  return std::string(vec.begin(), vec.end());
}

FaultyProcessGroupAgent::FaultyProcessGroupAgent(
    const c10::intrusive_ptr<::c10d::Store>& store,
    std::string workerName,
    c10::intrusive_ptr<::c10d::ProcessGroup> pg,
    int numSendRecvThreads,
    std::chrono::milliseconds rpcTimeout,
    std::unique_ptr<RequestCallback> cb,
    const std::vector<std::string>& messagesToFail,
    const std::unordered_map<std::string, float>& messageTypesToDelay,
    int failNumSends)
    : ProcessGroupAgent(
          store,
          std::move(workerName),
          std::move(pg),
          numSendRecvThreads,
          rpcTimeout,
          std::move(cb)),
      failNumSends_(failNumSends),
      messageTypesToFail_(parseMessagesToFailInput(messagesToFail)),
      messageTypesToDelay_(parseMessagesToDelay(messageTypesToDelay)) {}

std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
    const std::vector<std::string>& messagesToFail) const {
  // Since we can only pass strings corresponding to the Message Types from the
  // python tests, we must parse the list of strings and resolve the actual
  // types. We will then check this list of types in the send function to
  // determine whether we should fail or not.
  std::vector<MessageType> messageTypesToFail;
  messageTypesToFail.reserve(messagesToFail.size());
  for (const auto& msgString : messagesToFail) {
    messageTypesToFail.push_back(messageStringToType(msgString));
  }
  return messageTypesToFail;
}

std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
    parseMessagesToDelay(const std::unordered_map<std::string, float>&
                             messageTypesToDelay) const {
  std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
  for (const auto& messagePair : messageTypesToDelay) {
    float delay = messagePair.second;
    TORCH_CHECK(
        delay >= 0,
        "Delays passed to FaultyProcessGroupAgent must be non-negative.")
    delayMessages.insert({messageStringToType(messagePair.first), delay});
  }
  return delayMessages;
}

c10::intrusive_ptr<JitFuture> FaultyProcessGroupAgent::send(
    const WorkerInfo& to,
    c10::intrusive_ptr<Message> message,
    const float rpcTimeoutSeconds,
    const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
  // We only fail control messages that have been specified by the test case.
  // For all other messages, we just send them without any failures.
  if (!shouldFailMessage(message->type())) {
    return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
  }
  // This send function checks the failMessageCountMap_ to check whether
  // we must fail the next send. If the send must be failed, we set an error
  // on the returned future immediately and increment the counter in the map,
  // otherwise we just call the ProcessGroupAgent send.
  const auto key = fromVec(message->payload());
  std::unique_lock<std::mutex> lock(failMapMutex_);
  auto it = failMessageCountMap_.find(key);
  if (it == failMessageCountMap_.end()) {
    failMessageCountMap_[key] = 0;
  }
  if (failMessageCountMap_[key] < failNumSends_) {
    failMessageCountMap_[key]++;
    lock.unlock();
    auto jitFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
    jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError(
        c10::str("Send attempt failed intentionally for ", key),
        RPCErrorType::INTENTIONAL_FAILURE))));
    return jitFuture;
  } else {
    lock.unlock();
    return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
  }
}

void FaultyProcessGroupAgent::enqueueSend(SendWork work) {
  float msgDelay = getDelayForMessage(work.message_->type());
  if (msgDelay != 0) {
    // Sleep for the specified delay for the message.
    std::this_thread::sleep_for(std::chrono::milliseconds(
        static_cast<int>(msgDelay * kSecToMsConversion)));
  }
  ProcessGroupAgent::enqueueSend(std::move(work));
}

void FaultyProcessGroupAgent::sendToSelf(c10::intrusive_ptr<Message> message) {
  float msgDelay = getDelayForMessage(message->type());
  if (msgDelay != 0) {
    // Sleep for the specified delay for the message.
    std::this_thread::sleep_for(std::chrono::milliseconds(
        static_cast<int>(msgDelay * kSecToMsConversion)));
  }
  ProcessGroupAgent::sendToSelf(std::move(message));
}

bool FaultyProcessGroupAgent::shouldFailMessage(MessageType type) const {
  // Return true if the input message type is in the messageTypesToFail_ list
  return (
      std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
      messageTypesToFail_.end());
}

float FaultyProcessGroupAgent::getDelayForMessage(MessageType type) const {
  const auto& it = messageTypesToDelay_.find(type);
  return it == messageTypesToDelay_.end() ? 0 : it->second;
}

MessageType FaultyProcessGroupAgent::messageStringToType(
    const std::string& messageString) const {
  // Lazily constructed map that returns string to message type mapping
  static std::unordered_map<std::string, MessageType> msgMap = {
      {"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
      {"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
      {"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
      {"CLEANUP_AUTOGRAD_CONTEXT_REQ",
       MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
      {"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
      {"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
      {"PYTHON_CALL", MessageType::PYTHON_CALL},
      {"SCRIPT_CALL", MessageType::SCRIPT_CALL},
      {"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
      {"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
  const auto& it = msgMap.find(messageString);
  TORCH_CHECK(
      it != msgMap.end(),
      "No mapping to rpc::MessageType exists for ",
      messageString);
  return it->second;
}

} // namespace rpc
} // namespace distributed
} // namespace torch
@@ -1,99 +0,0 @@
#pragma once

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/process_group_agent.h>

namespace torch {
namespace distributed {
namespace rpc {

struct TORCH_API FaultyProcessGroupRpcBackendOptions
    : public ProcessGroupRpcBackendOptions {
  FaultyProcessGroupRpcBackendOptions(
      int num_send_recv_threads,
      float rpc_timeout,
      std::string init_method,
      std::vector<std::string> messages_to_fail,
      std::unordered_map<std::string, float> messages_to_delay,
      int num_fail_sends = 0)
      : ProcessGroupRpcBackendOptions(
            num_send_recv_threads,
            rpc_timeout,
            std::move(init_method)),
        messagesToFail(std::move(messages_to_fail)),
        messagesToDelay(std::move(messages_to_delay)),
        numFailSends(num_fail_sends) {
    TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
  }

  std::vector<std::string> messagesToFail;
  std::unordered_map<std::string, float> messagesToDelay;
  int numFailSends;
};

class TORCH_API FaultyProcessGroupAgent : public ProcessGroupAgent {
 public:
  FaultyProcessGroupAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string workerName,
      c10::intrusive_ptr<c10d::ProcessGroup> pg,
      int numSendRecvThreads,
      std::chrono::milliseconds rpcTimeout,
      std::unique_ptr<RequestCallback> cb,
      const std::vector<std::string>& messagesToFail,
      const std::unordered_map<std::string, float>& messageTypesToDelay,
      int failNumSends = 0);

  // Faulty send function for this class.
  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
      const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
      override;

 protected:
  // This function checks the messageTypesToFail_ to determine whether to use
  // the faulty send or not.
  virtual bool shouldFailMessage(MessageType type) const;

 private:
  // Overrides ProcessGroupAgent's enqueueSend to inject delays.
  void enqueueSend(SendWork work) override;
  // Override ProcessGroupAgent's sendToSelf to inject delays.
  void sendToSelf(c10::intrusive_ptr<Message> message) override;
  // This function parses the list of strings passed in by the python tests and
  // resolves the Message Types that must use the faulty send.
  std::vector<MessageType> parseMessagesToFailInput(
      const std::vector<std::string>& messagesToFail) const;

  // Returns amount of time in seconds to delay sending of the given message
  // type.
  float getDelayForMessage(MessageType type) const;

  // Parse message types that we should inject arbitrary delays for.
  std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
      const std::unordered_map<std::string, float>& messageTypesToDelay) const;

  // Number of sends to intentionally fail before allowing one to succeed.
  const int failNumSends_;

  // Vector of the MessageTypes that we must use the faulty send for. This is
  // parsed based on a list of strings passed in by the python tests.
  const std::vector<MessageType> messageTypesToFail_;

  // Mapping of message types to amount we should delay send for in the ::send()
  // function.
  std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;

  // Map to track the number of sends we've failed for each RPC.
  std::unordered_map<std::string, int> failMessageCountMap_;

  // Mutex to guard failMessageCountMap_
  std::mutex failMapMutex_;

  MessageType messageStringToType(const std::string& messageString) const;
};
} // namespace rpc
} // namespace distributed
} // namespace torch