Remove process_group_agent and faulty_process_group_agent files (#62985)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62985

Remove the process_group_agent and faulty_process_group_agent code now that the PROCESS_GROUP backend has been deprecated for RPC (https://github.com/pytorch/pytorch/issues/55615). Discussed with xush6528 that it was okay to remove ProcessGroupAgentTest and ProcessGroupAgentBench, which depended on process_group_agent.

Test Plan: CI tests

Reviewed By: pritamdamania87

Differential Revision: D30195576

fbshipit-source-id: 8b4381cffadb868b19d481198015d0a67b205811
Author: Howard Huang
Date: 2021-08-10 15:56:18 -07:00
Committed by: Facebook GitHub Bot
Parent: 790553811c
Commit: 4d0497034c
7 changed files with 0 additions and 1454 deletions


@@ -10,7 +10,6 @@ set(TORCH_RPC_TEST_DEPENDENCY_LIBS
if(USE_GLOO)
list(APPEND TORCH_RPC_TEST_SOURCES
${TORCH_RPC_TEST_DIR}/test_e2e_process_group.cpp
${TORCH_RPC_TEST_DIR}/test_e2e_tensorpipe.cpp
)
endif()


@@ -1,47 +0,0 @@
#include <gtest/gtest.h>
#include "e2e_test_base.h"
#include <c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/request_callback_no_python.h>
#include <torch/torch.h>
namespace torch {
namespace distributed {
namespace rpc {
class TestE2EProcessGroup : public TestE2EBase {
protected:
void buildRpcAgent() override {
auto options = c10d::ProcessGroupGloo::Options::create();
options->devices.push_back(
::c10d::ProcessGroupGloo::createDeviceForHostname(serverAddress));
std::chrono::milliseconds rpcTimeout(30000);
options->timeout = rpcTimeout;
// Initialize server rpc agent.
auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
store, 0, numWorkers, options);
rpcAgent = std::make_shared<ProcessGroupAgent>(
store,
"worker",
pg,
std::max(16U, std::thread::hardware_concurrency()),
rpcTimeout,
std::make_unique<RequestCallbackNoPython>());
}
};
// End to end training loop test in C++ so that we can run LSAN on this test to
// catch memory leaks. Enabling LSAN with python multiprocessing has been
// challenging and we don't have a good solution yet.
TEST_F(TestE2EProcessGroup, TestTrainingLoop) {
runTrainingLoop();
}
} // namespace rpc
} // namespace distributed
} // namespace torch

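Illustration (not part of this diff): a minimal gtest sketch of the fixture pattern the removed test relies on, where a subclass overrides buildRpcAgent() and the base fixture drives the run. FakeAgent, E2EFixtureSketch, and the "tensorpipe" string are hypothetical stand-ins; the surviving test_e2e_tensorpipe.cpp presumably follows the same override pattern with a TensorPipe-based agent.

#include <gtest/gtest.h>
#include <memory>
#include <string>

// FakeAgent stands in for a real RpcAgent; the real fixture is TestE2EBase
// from e2e_test_base.h.
struct FakeAgent {
  std::string backend;
};

class E2EFixtureSketch : public ::testing::Test {
 protected:
  // Mirrors TestE2EBase::buildRpcAgent(): subclasses decide which agent to build.
  virtual void buildRpcAgent() {
    agent = std::make_unique<FakeAgent>();
    agent->backend = "tensorpipe";
  }

  void SetUp() override {
    buildRpcAgent();
  }

  std::unique_ptr<FakeAgent> agent;
};

TEST_F(E2EFixtureSketch, BuildsAgent) {
  ASSERT_NE(agent, nullptr);
  EXPECT_EQ(agent->backend, "tensorpipe");
}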

@@ -371,7 +371,6 @@ libtorch_distributed_extra_sources = [
"torch/csrc/distributed/rpc/python_call.cpp",
"torch/csrc/distributed/rpc/python_remote_call.cpp",
"torch/csrc/distributed/rpc/python_resp.cpp",
"torch/csrc/distributed/rpc/process_group_agent.cpp",
"torch/csrc/distributed/rpc/request_callback.cpp",
"torch/csrc/distributed/rpc/request_callback_no_python.cpp",
"torch/csrc/distributed/rpc/rpc_agent.cpp",
@@ -383,7 +382,6 @@ libtorch_distributed_extra_sources = [
"torch/csrc/distributed/rpc/script_resp.cpp",
"torch/csrc/distributed/rpc/tensorpipe_agent.cpp",
"torch/csrc/distributed/rpc/tensorpipe_utils.cpp",
"torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp",
"torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp",
"torch/csrc/distributed/rpc/torchscript_functions.cpp",
"torch/csrc/distributed/rpc/types.cpp",


@@ -1,863 +0,0 @@
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <c10/util/C++17.h>
#include <c10/util/irange.h>
#include <c10d/ProcessGroup.hpp>
#include <fmt/format.h>
#include <torch/csrc/distributed/rpc/agent_utils.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
namespace distributed {
namespace rpc {
////////////////////////// MessageCounter /////////////////////////////////
ProcessGroupAgent::MessageCounter::MessageCounter(int worldSize)
: counters_(worldSize) {}
void ProcessGroupAgent::MessageCounter::increment(int dst) {
std::lock_guard<std::mutex> guard(mutex_);
++counters_[dst];
}
std::vector<int64_t> ProcessGroupAgent::MessageCounter::snapshot() {
std::lock_guard<std::mutex> guard(mutex_);
return counters_;
}
////////////////////////// MetricsTracker /////////////////////////////////
ProcessGroupAgent::AverageMetricsTracker::AverageMetricsTracker(
std::string key,
uint64_t currentSum,
uint64_t currentCount)
: key_(std::move(key)),
currentSum_(currentSum),
currentCount_(currentCount) {}
void ProcessGroupAgent::AverageMetricsTracker::addData(uint64_t dataPoint) {
currentSum_ += dataPoint;
++currentCount_;
}
double ProcessGroupAgent::AverageMetricsTracker::computeAverage() {
return currentCount_ == 0 ? 0 : currentSum_ / (double)currentCount_;
}
//////////////////////// ProcessGroupAgent /////////////////////////////////
using steady_clock_time_point =
std::chrono::time_point<std::chrono::steady_clock>;
const steady_clock_time_point kInfiniteTimeoutTimePoint =
std::chrono::time_point<std::chrono::steady_clock>::max();
const std::string kNumPendingRequests = "agent.num_pending_requests";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
ProcessGroupAgent::ProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb)
: RpcAgent(
WorkerInfo(std::move(workerName), (int64_t)pg->getRank()),
std::move(cb),
rpcTimeout),
pg_(std::move(pg)),
sendCounts_(pg_->getSize()),
recvCounts_(pg_->getSize()),
nextId_(0),
sendMutexes_(pg_->getSize()),
threadPool_(numSendRecvThreads),
timeoutThreadEnabled_{false} {
// initialize metric info counters
metrics_.resize(ProcessGroupAgentMetrics::N_METRICS);
metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME] =
std::make_unique<AverageMetricsTracker>(kGilAverageWaitTime);
nameMap_ = collectNames(
::c10d::PrefixStore("names", store),
workerInfo_.id_,
workerInfo_.name_,
pg_->getSize());
auto workerRankIter = nameMap_.find(workerInfo_.name_);
TORCH_CHECK(
workerRankIter != nameMap_.end(),
"Failed to resolve worker "
"name ",
workerInfo_.name_,
" to a ProcessGroup rank.");
TORCH_CHECK(
pg_->getRank() == workerRankIter->second,
"Resolved worker rank ",
workerRankIter->second,
" does not match ProcessGroup rank ",
pg_->getRank());
// tmp vector to sort names in rank's order
const auto worldSize = pg_->getSize();
std::vector<std::string> tmpWorkerIds(worldSize);
for (auto& entry : nameMap_) {
tmpWorkerIds[entry.second] = entry.first;
}
allWorkerInfo_.reserve(worldSize);
for (worker_id_t rank = 0; rank < worldSize; ++rank) {
allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank);
}
}
ProcessGroupAgent::~ProcessGroupAgent() {
if (rpcAgentRunning_) {
shutdown();
}
}
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(
const std::string& workerName) const {
const auto idIter = nameMap_.find(workerName);
TORCH_CHECK(
idIter != nameMap_.end(), "Unknown destination worker ", workerName);
return allWorkerInfo_[idIter->second];
}
const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const {
TORCH_CHECK(
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
id >= 0 && id < allWorkerInfo_.size(),
"Invalid destination: ",
id);
return allWorkerInfo_[id];
}
std::vector<WorkerInfo> ProcessGroupAgent::getWorkerInfos() const {
return allWorkerInfo_;
}
void ProcessGroupAgent::join(bool /* unused */) {
sync();
std::unique_lock<std::mutex> lock(futureMutex_);
futureCV_.wait(
lock, [this] { return futures_.empty() && futureTimeouts_.empty(); });
lock.unlock();
pg_->barrier()->wait();
}
bool ProcessGroupAgent::hasPendingMessage() {
const auto worldSize = pg_->getSize();
auto snapshot = std::make_unique<std::vector<int64_t>>();
snapshot->reserve(2 * worldSize);
auto recvSnapshot = recvCounts_.snapshot();
auto sendSnapshot = sendCounts_.snapshot();
snapshot->insert(
snapshot->end(),
std::make_move_iterator(recvSnapshot.begin()),
std::make_move_iterator(recvSnapshot.end()));
snapshot->insert(
snapshot->end(),
std::make_move_iterator(sendSnapshot.begin()),
std::make_move_iterator(sendSnapshot.end()));
auto snapshotData = snapshot->data();
auto deleteWhenDone = snapshot.release();
std::vector<torch::Tensor> inputSnapshot = {torch::from_blob(
snapshotData,
{2, worldSize},
[deleteWhenDone](void*) { delete deleteWhenDone; },
{torch::kInt64})};
// allgather both send and recv messages in one shot
std::vector<std::vector<torch::Tensor>> outputSnapshots(1);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores,clang-diagnostic-unused-variable)
for (const auto i : c10::irange(worldSize)) {
outputSnapshots[0].emplace_back(
torch::zeros({2, worldSize}, {torch::kInt64}));
}
pg_->allgather(outputSnapshots, inputSnapshot)->wait();
// loop through all send/recv pairs to make sure that all sent messages are
// processed.
const auto& peerCounts = outputSnapshots[0];
for (const auto from : c10::irange(worldSize)) {
for (const auto to : c10::irange(worldSize)) {
// peerCounts[x][0] is recv counts, and peerCounts[x][1] is send counts
const auto& sentCnt = peerCounts[from][1][to].data_ptr<int64_t>()[0];
const auto& recvCnt = peerCounts[to][0][from].data_ptr<int64_t>()[0];
// NB: we cannot throw an error when sentCnt < recvCnt here, because send
// and recv counts on different workers are read in a distributed manner.
// It is possible that the sender reads its send count before sending, but
// the receiver reads its recv count after receiving. Hence, both > and <
// are valid states.
if (sentCnt != recvCnt) {
return true;
}
}
}
return false;
}
void ProcessGroupAgent::sync() {
// Block until all processes want to sync.
pg_->barrier()->wait();
// block until all peers agree that all sent messages have been processed.
do {
// Finish all send/recv tasks in the thread pool
threadPool_.waitWorkComplete();
// As there could be nested RPC calls, or response callback could also
// trigger more messages to be sent, we need to wait for the thread pool
// again.
} while (hasPendingMessage());
}
void ProcessGroupAgent::startImpl() {
timeoutThreadEnabled_.store(true);
listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this);
futureTimeoutThread_ =
std::thread(&ProcessGroupAgent::pollTimedOutRPCs, this);
}
void ProcessGroupAgent::shutdownImpl() {
LOG(INFO) << "Shutting down ProcessGroupAgent on rank " << pg_->getRank()
<< ".";
{
std::unique_lock<std::mutex> lock(futureMutex_);
timeoutThreadEnabled_.store(false);
}
futureTimeoutCV_.notify_one();
futureTimeoutThread_.join();
// Abort listener thread to stop accepting new work. We need to interrupt the
// recvWork->wait() call the listener loop may be blocked in before joining
// the thread.
{
std::unique_lock<std::mutex> lock(recvWorkMutex_);
if (recvWork_) {
recvWork_->abort();
}
}
listenerThread_.join();
// Abort any pending sends to any destination rank that have not been
// completed.
{
std::lock_guard<std::mutex> lock(pendingSendMutex_);
for (auto& it : currentPendingSends_) {
const auto& pendingSends = it.second;
const auto dst = it.first;
for (const auto& send : pendingSends) {
if (!send->isCompleted()) {
LOG(INFO) << "Worker " << RpcAgent::getWorkerInfo().id_
<< " aborting pending send to destination rank " << dst;
send->abort();
}
}
}
}
// Note: calling threadPool_.waitWorkComplete() after listenerThread.join() so
// that we can finish any possible work enqueued into the thread pool, before
// python RPC handler is shutdown (see shutdown in rpc/api.py).
threadPool_.waitWorkComplete();
}
c10::intrusive_ptr<JitFuture> ProcessGroupAgent::send(
const WorkerInfo& to,
c10::intrusive_ptr<Message> message,
const float rpcTimeoutSeconds,
const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
// Throw if we previously encountered an exception in ::listenLoop.
{
std::unique_lock<std::mutex> guard(listenLoopExceptionMutex_);
if (listenLoopException_) {
std::rethrow_exception(listenLoopException_);
}
}
if (!rpcAgentRunning_.load()) {
// We are trying to send but RPC has been shut down on this node. This can
// happen if we are in a shutdown sequence but background threads are still
// processing messages that result in send()s. Throw a descriptive error.
auto err = c10::str(
"Node ",
RpcAgent::getWorkerInfo().id_,
"tried to send() a message of type ",
message->type(),
" but RPC is no longer running on this node.");
TORCH_CHECK(false, err);
}
TORCH_CHECK(
to.id_ < (worker_id_t)pg_->getSize(),
"Destination rank is out of bound, got ",
to.id_,
", but world size is ",
pg_->getSize());
auto requestId = nextId();
auto future = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
if (message->isRequest()) {
// millisecond level precision of when request started.
auto futureStartTime = std::chrono::steady_clock::now();
// if passed in timeout is unset, then use the currently set default timeout
// for all RPCs.
auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
? getRpcTimeout()
: std::chrono::milliseconds(
static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
// Prepare endTime from timeout. Set infinite timeout if
// specified.
steady_clock_time_point endTime = timeout.count() == 0
? kInfiniteTimeoutTimePoint
: futureStartTime + timeout;
bool notifyThread = false;
{
std::lock_guard<std::mutex> lock{futureMutex_};
// Insert future into future map.
futures_.emplace(
std::piecewise_construct,
std::forward_as_tuple(requestId),
std::forward_as_tuple(FutureInfo(future, endTime, to.id_, timeout)));
// insert future into timeouts map to keep track of its timeout
auto& requestIds = futureTimeouts_[endTime];
requestIds.insert(requestId);
// Signal the watchdog to monitor future timeouts if this is the first
// future created or it has earlier end time than other futures in the
// map.
if (futureTimeouts_.begin()->first == endTime &&
(requestIds.size() == 1)) {
notifyThread = true;
}
}
if (notifyThread) {
// Notify the watchdog thread only after releasing the lock,
// so watchdog can acquire lock on waking up.
futureTimeoutCV_.notify_one();
}
message->setId(requestId);
++clientActiveCalls_;
} else {
future->markCompleted(IValue());
}
// Sending to ourselves: bypass the send logic and enqueue directly
// to our receiving queue.
if (to.id_ == (worker_id_t)pg_->getRank()) {
sendToSelf(std::move(message));
return future;
}
// NB: cannot directly pass ``to`` to the ``SendWork``, because it might no
// longer be alive when the ``SendWork`` is executed. For example, the
// application could query the ``WorkerInfo`` using name through the
// ``RpcAgent::getWorkerInfo`` API, and pass the ``WorkerInfo`` back here, so
// we have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerInfo``
// reference on Python side could die before ``SendWork`` uses it, and Pybind
// will not keep the Python reference alive even if it originally comes from
// the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the
// C++ land.
enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message)));
return future;
}
void ProcessGroupAgent::handleSend(const SendWork& work) {
// NOLINTNEXTLINE(clang-diagnostic-pessimizing-move)
auto serializedPayload = std::make_unique<std::string>(std::move(
wireSerialize(work.message_->payload(), work.message_->tensors())));
std::vector<torch::Tensor> preamble = {torch::tensor(
{(int64_t)pg_->getRank(),
(int64_t)serializedPayload->length(),
(int64_t)work.message_->type(),
(int64_t)work.message_->id()},
{torch::kInt64})};
// ProcessGroup is not thread-safe when sending with the same tag,
// hence the lock
std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> pendingSends;
const auto dst = work.to_.id_;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto serializedPayloadData = const_cast<char*>(serializedPayload->data());
auto serializedPayloadSize = serializedPayload->size();
std::string* deleteWhenDone = serializedPayload.release();
std::vector<torch::Tensor> payload = {torch::from_blob(
reinterpret_cast<void*>(serializedPayloadData),
serializedPayloadSize,
[deleteWhenDone](void*) { delete deleteWhenDone; },
{torch::kChar})};
pendingSends.reserve(2);
sendCounts_.increment(dst);
{
std::lock_guard<std::mutex> guard(sendMutexes_[dst]);
pendingSends.emplace_back(pg_->send(preamble, dst, dst /* channelTag */));
pendingSends.emplace_back(pg_->send(payload, dst, dst /* channelTag */));
}
// Write pendingSends to a global map so that they can be interrupted by
// ::shutdown().
{
std::lock_guard<std::mutex> pendingSendGuard(pendingSendMutex_);
for (auto& p : pendingSends) {
currentPendingSends_[dst].insert(p);
}
}
for (auto& pendingSend : pendingSends) {
if (!rpcAgentRunning_.load() || !pendingSend->wait()) {
// Send was interrupted or RPC is not running.
return;
}
}
// Erase the pending sends that we added since we have returned from wait.
{
std::lock_guard<std::mutex> pendingSendGuard(pendingSendMutex_);
// NB: We cannot just erase all of currentPendingSends[dst], since this
// might preemptively remove sends from other threads.
auto& set = currentPendingSends_[dst];
for (auto& p : pendingSends) {
set.erase(p);
}
}
}
void ProcessGroupAgent::sendToSelf(c10::intrusive_ptr<Message> message) {
// NOLINTNEXTLINE(modernize-avoid-bind)
threadPool_.run(std::bind(
[this](c10::intrusive_ptr<Message> message) {
// Unlike the other cases, need to add a tensor deleter, since the
// data outlives the scope of this function. The payload is released to a
// raw pointer that the tensor's deleter frees once the tensor is done.
std::unique_ptr<std::string> payload;
try {
payload = std::make_unique<std::string>(
wireSerialize(message->payload(), message->tensors()));
// only increment sendCounts when the message is indeed added into
// local recv.
sendCounts_.increment(pg_->getRank());
} catch (std::exception& e) {
markFutureWithError(message->id(), e.what());
return;
}
const char* data = payload->data();
size_t len = payload->length();
std::string* delete_when_done = payload.release();
enqueueRecv(RecvWork(
getWorkerInfo(pg_->getRank()),
message->type(),
message->id(),
torch::from_blob(
(void*)data,
len,
[delete_when_done](void*) { delete delete_when_done; },
{torch::kChar})));
},
std::move(message)));
}
void ProcessGroupAgent::enqueueSend(SendWork work) {
// NB: this can be changed to use a native move capture when moved to C++14
// NOLINTNEXTLINE(modernize-avoid-bind)
threadPool_.run(std::bind(
[this](const SendWork& work) {
try {
handleSend(work);
} catch (std::exception& e) {
auto errorStr = c10::str(
"Encountered exception in ProcessGroupAgent::enqueueSend: ",
e.what(),
" on node: ",
RpcAgent::getWorkerInfo().id_);
auto exceptionMsg =
rpc::createExceptionResponse(errorStr, work.message_->id());
if (work.message_->isRequest()) {
// Mark the future corresponding to this request with an error.
markFutureWithError(*exceptionMsg);
} else if (work.message_->isResponse()) {
// Try sending the error along.
handleSend(SendWork(work.to_, std::move(exceptionMsg)));
}
}
},
std::move(work)));
}
bool ProcessGroupAgent::handleRecv(RecvWork& work) {
torch::Tensor& payload = work.payload_;
auto data = wireDeserialize(payload.storage().data(), payload.numel());
auto message = c10::make_intrusive<Message>(
std::move(data.first), std::move(data.second), work.type_, work.id_);
if (message->isRequest()) {
++serverActiveCalls_;
c10::intrusive_ptr<JitFuture> futureResponse;
try {
futureResponse = cb_->operator()(*message, {});
} catch (const std::exception& e) {
futureResponse = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
futureResponse->setError(std::current_exception());
}
if (futureResponse->completed()) {
--serverActiveCalls_;
if (!futureResponse->hasError()) {
send(work.from_, futureResponse->value().toCustomClass<Message>());
} else {
send(
work.from_,
createExceptionResponse(
futureResponse->tryRetrieveErrorMessage(), message->id()));
}
} else {
++serverActiveAsyncCalls_;
// Callback processing returned an incomplete future. Add sending the
// response as a callback which fires when the future completes.
auto fromId = work.from_.id_;
auto requestId = work.id_;
futureResponse->addCallback(
[this, fromId, requestId](JitFuture& futureResponse) {
--serverActiveCalls_;
--serverActiveAsyncCalls_;
if (!futureResponse.hasError()) {
send(
getWorkerInfo(fromId),
futureResponse.value().toCustomClass<Message>());
} else {
send(
getWorkerInfo(fromId),
createExceptionResponse(
futureResponse.tryRetrieveErrorMessage(), requestId));
}
});
}
} else if (message->isResponse()) {
auto id = message->id();
c10::intrusive_ptr<JitFuture> jitFuture;
{
std::lock_guard<std::mutex> lock{futureMutex_};
const auto& futureInfo = futures_.find(id);
if (futureInfo == futures_.end()) {
// Received a completion for an already-processed future (such as one
// that timed out), so drop the recv. By returning false, recvCounts will
// not be incremented here; it will be incremented by the thread that
// determined that the future timed out.
return false;
}
// Use futureInfo before destructing it.
jitFuture = futureInfo->second.future_;
auto endTime = futureInfo->second.endTime_;
futures_.erase(id);
// look up the corresponding future by its timeout and request
// ID, and remove it from the timeouts map
auto& futuresAtTime = futureTimeouts_[endTime];
auto it = futuresAtTime.find(id);
TORCH_INTERNAL_ASSERT(
it != futuresAtTime.end(),
"Error: could not find future in futureTimeouts map, race condition.");
futuresAtTime.erase(it);
if (futuresAtTime.empty()) {
// remove the key from futureTimeouts_
futureTimeouts_.erase(endTime);
}
}
futureCV_.notify_all();
--clientActiveCalls_;
if (message->type() == MessageType::EXCEPTION) {
jitFuture->setError(std::make_exception_ptr(std::runtime_error(
std::string(message->payload().begin(), message->payload().end()))));
} else {
jitFuture->markCompleted(std::move(message));
}
} else {
// TODO: pass the error back to the caller instead of crashing here.
TORCH_INTERNAL_ASSERT(false, "unrecognized message type ", message->type());
}
return true;
}
void ProcessGroupAgent::enqueueRecv(RecvWork work) {
// NOLINTNEXTLINE(modernize-avoid-bind)
threadPool_.run(std::bind(
[&](RecvWork& work) {
try {
// Only increment recvCounts if handleRecv() tells us to. We may not,
// e.g. if we process work corresponding to a future that has already
// been processed.
if (handleRecv(work)) {
recvCounts_.increment(work.from_.id_);
}
} catch (const std::exception& e) {
// Processing for this request/response failed. Log the details of the
// request.
auto fromId = work.from_.id_;
auto err = c10::str(
"Internal error while processing request of type ",
work.type_,
" on node ",
RpcAgent::getWorkerInfo().id_,
", from node ",
fromId,
" : ",
e.what());
LOG(INFO) << err;
// Still increment so that this recv is recognized as non-outstanding
// during graceful shutdown.
recvCounts_.increment(work.from_.id_);
}
},
std::move(work)));
}
void ProcessGroupAgent::markFutureWithError(Message& message) {
TORCH_INTERNAL_ASSERT(
message.type() == MessageType::EXCEPTION,
"markFutureWithError should be only called with Message that has type Exception.");
markFutureWithError(
message.id(),
std::string(message.payload().begin(), message.payload().end()));
}
void ProcessGroupAgent::markFutureWithError(int64_t id, std::string errorMsg) {
c10::intrusive_ptr<JitFuture> jitFuture;
{
std::lock_guard<std::mutex> lock{futureMutex_};
const auto& futureInfo = futures_.find(id);
if (futureInfo == futures_.end()) {
// Did not find future in map - this can occur when the future has timed
// out and been processed accordingly.
return;
}
jitFuture = futureInfo->second.future_;
auto rpcEndTime = futureInfo->second.endTime_;
futures_.erase(id);
// look up the corresponding future by its timeout and request ID,
// and remove it from the timeouts map
auto& futuresAtTime = futureTimeouts_[rpcEndTime];
auto it = futuresAtTime.find(id);
TORCH_INTERNAL_ASSERT(
it != futuresAtTime.end(),
"Error: could not find future in futureTimeouts map, race condition.");
futuresAtTime.erase(it);
if (futuresAtTime.empty()) {
// remove the key from futureTimeouts_
futureTimeouts_.erase(rpcEndTime);
}
}
--clientActiveCalls_;
jitFuture->setError(std::make_exception_ptr(std::runtime_error(errorMsg)));
futureCV_.notify_all();
}
void ProcessGroupAgent::listenLoop() {
try {
listenLoopInternal();
} catch (const std::exception& e) {
// Error occurred in listenLoop(). Stop the receiving thread and store the
// exception to indicate that the RPC agent is in an unhealthy state and
// we should shutdown.
auto err = c10::str(
"Encountered exception in ProcessGroupAgent::listenLoop(): ",
e.what(),
" on worker ",
RpcAgent::getWorkerInfo().id_,
". This means that the RPC agent is in an unhealthy state and unusable.");
LOG(ERROR) << err;
{
// Lock write to listenLoopException_ since ::send() reads from it.
std::lock_guard<std::mutex> guard(listenLoopExceptionMutex_);
listenLoopException_ = std::current_exception();
}
} catch (...) {
std::string unknownErrorMsg =
"Unknown exception occured in "
"ProcessGroupAgent::listenLoop. RPC Agent is in an unhealthy state and "
"unusable.";
LOG(ERROR) << unknownErrorMsg;
{
// Lock write to listenLoopException_ since ::send() reads from it.
std::lock_guard<std::mutex> guard(listenLoopExceptionMutex_);
listenLoopException_ =
std::make_exception_ptr(std::runtime_error(unknownErrorMsg));
}
}
}
void ProcessGroupAgent::listenLoopInternal() {
while (rpcAgentRunning_.load()) {
// rank, tensor size, message type, message id
std::vector<torch::Tensor> preamble = {torch::empty({4}, {torch::kInt64})};
auto work = pg_->recvAnysource(preamble, pg_->getRank());
{
// Write class variable so it can be aborted by shutdown()
std::lock_guard<std::mutex> guard(recvWorkMutex_);
recvWork_ = work;
}
if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
return;
}
int64_t* preamble_items = preamble.front().storage().data<int64_t>();
auto srcRank = preamble_items[0];
auto size = preamble_items[1];
MessageType type = MessageType(preamble_items[2]);
int64_t id = preamble_items[3];
std::vector<torch::Tensor> tensors = {torch::empty({size}, {torch::kChar})};
work = pg_->recv(tensors, srcRank, pg_->getRank());
{
// Write class variable so it can be aborted by shutdown()
std::lock_guard<std::mutex> guard(recvWorkMutex_);
recvWork_ = work;
}
if (!rpcAgentRunning_.load() || !work->wait() /* aborted */) {
return;
}
enqueueRecv(
RecvWork(allWorkerInfo_[srcRank], type, id, std::move(tensors[0])));
}
}
void ProcessGroupAgent::pollTimedOutRPCs() {
while (timeoutThreadEnabled_.load()) {
std::unique_lock<std::mutex> lock{futureMutex_};
steady_clock_time_point minEndTime;
// Estimate amount of time the first future will time out in, and sleep
// for that long.
// if there are no futures or the first future's RPC timeout is set to 0
// (meaning no timeout), then sleep for a set "infinity" time.
if (futureTimeouts_.empty()) {
minEndTime = kInfiniteTimeoutTimePoint;
} else {
minEndTime = futureTimeouts_.begin()->first;
}
auto shouldUpdateMinEndTimePredicate = [&, this]() -> bool {
// Notice, whoever modifies `timeoutThreadEnabled_`
// must acquire a lock on `futureMutex_`.
// Otherwise, this predicate could deadlock.
// If during evaluating the predicate, `::shutdown()` is called, then
// the predicate missed the notification before it started waiting
// on the cond var.
if (!timeoutThreadEnabled_.load()) {
return true;
}
steady_clock_time_point minEndTimeInMap = kInfiniteTimeoutTimePoint;
if (futureTimeouts_.empty()) {
minEndTimeInMap = kInfiniteTimeoutTimePoint;
} else {
minEndTimeInMap = futureTimeouts_.begin()->first;
}
return minEndTimeInMap < minEndTime;
};
bool shouldUpdateMinEndTime = true;
if (minEndTime == kInfiniteTimeoutTimePoint) {
futureTimeoutCV_.wait(lock, shouldUpdateMinEndTimePredicate);
} else {
shouldUpdateMinEndTime = futureTimeoutCV_.wait_until(
lock, minEndTime, shouldUpdateMinEndTimePredicate);
}
if (shouldUpdateMinEndTime) {
continue;
}
const auto timedOutFutures = processTimedOutFutures();
lock.unlock();
futureCV_.notify_all();
for (const auto& timedOutFuture : timedOutFutures) {
auto errStr =
fmt::format(kRpcTimeoutErrorStr, timedOutFuture.timeout_.count());
auto err = makeRPCError(errStr, RPCErrorType::TIMEOUT);
if (!timedOutFuture.future_->hasError()) {
--clientActiveCalls_;
timedOutFuture.future_->setError(
std::make_exception_ptr(std::runtime_error(err)));
// The future timed out and will not be processed by handleRecv(), even
// if we eventually get a response. In order to keep track of all
// send/recv pairs, we increment the count here.
const int dst = timedOutFuture.dstRank_;
recvCounts_.increment(dst);
}
}
}
}
const std::vector<ProcessGroupAgent::FutureInfo> ProcessGroupAgent::
processTimedOutFutures() {
std::vector<FutureInfo> timedOutFutures;
for (auto it = futureTimeouts_.begin(); it != futureTimeouts_.end();
/* intentional no increment */) {
const auto& endTime = it->first;
if (std::chrono::steady_clock::now() < endTime) {
// Since the futureTimeouts_ map is ordered by timeout, we don't need
// to check the remaining futures.
break;
} else {
const auto& futureIDs = it->second;
for (const auto& futureID : futureIDs) {
auto futureIt = futures_.find(futureID);
TORCH_INTERNAL_ASSERT(
futureIt != futures_.end(),
"Race Condition - Expected future does not exist in map");
const auto futInfo = futureIt->second;
timedOutFutures.push_back(futInfo);
futures_.erase(futureID);
}
it = futureTimeouts_.erase(it);
}
}
return timedOutFutures;
}
std::unordered_map<std::string, std::string> ProcessGroupAgent::getMetrics() {
std::unordered_map<std::string, std::string> metrics;
{
std::unique_lock<std::mutex> lock(futureMutex_);
auto futuresSize = futures_.size();
lock.unlock();
metrics[kNumPendingRequests] = c10::to_string(futuresSize);
}
metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
metrics[kNumIdleThreads] = c10::to_string(threadPool_.numAvailable());
metrics[kClientActiveCalls] = c10::to_string(clientActiveCalls_.load());
metrics[kServerActiveCalls] = c10::to_string(serverActiveCalls_.load());
metrics[kServerActiveAsyncCalls] =
c10::to_string(serverActiveAsyncCalls_.load());
if (isGILProfilingEnabled()) {
// Add time-series based metrics, just GIL wait times for now.
{
std::unique_lock<std::mutex> lock(metricsMutex_);
auto avgGilWaitTime = metrics_[GIL_WAIT_TIME]->computeAverage();
lock.unlock();
metrics[kGilAverageWaitTime] = c10::to_string(avgGilWaitTime);
}
}
return metrics;
}
void ProcessGroupAgent::addGilWaitTime(
const std::chrono::microseconds gilWaitTime) {
std::lock_guard<std::mutex> lock(metricsMutex_);
metrics_[ProcessGroupAgentMetrics::GIL_WAIT_TIME]->addData(
gilWaitTime.count());
}
} // namespace rpc
} // namespace distributed
} // namespace torch

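Illustration (not part of this diff): a self-contained sketch of the counter-matching check that hasPendingMessage() performs after the allgather. CounterSnapshot and the plain std::vector layout are stand-ins for the gathered tensors; the pairwise comparison is the same idea as the loop over peerCounts above.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Each worker contributes one snapshot via allgather: recv[i] counts messages
// from worker i that have been fully processed, send[i] counts messages sent
// to worker i.
struct CounterSnapshot {
  std::vector<int64_t> recv;
  std::vector<int64_t> send;
};

// Work is still pending if any sender's count toward a peer differs from that
// peer's processed-recv count. Both > and < are transient states, so any
// mismatch means "not terminated yet" rather than an error.
bool hasPendingMessage(const std::vector<CounterSnapshot>& peerCounts) {
  const std::size_t worldSize = peerCounts.size();
  for (std::size_t from = 0; from < worldSize; ++from) {
    for (std::size_t to = 0; to < worldSize; ++to) {
      if (peerCounts[from].send[to] != peerCounts[to].recv[from]) {
        return true;
      }
    }
  }
  return false;
}

int main() {
  // Two workers: worker 0 sent one message to worker 1, which worker 1 has
  // not processed yet, so the check reports pending work.
  std::vector<CounterSnapshot> counts = {
      {{0, 0}, {0, 1}}, // worker 0
      {{0, 0}, {0, 0}}, // worker 1
  };
  std::cout << std::boolalpha << hasPendingMessage(counts) << "\n"; // true
  counts[1].recv[0] = 1; // worker 1 processed it
  std::cout << hasPendingMessage(counts) << "\n"; // false
}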

@@ -1,290 +0,0 @@
#pragma once
#include <c10/core/thread_pool.h>
#include <c10d/PrefixStore.hpp>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <atomic>
#include <thread>
namespace torch {
namespace distributed {
namespace rpc {
constexpr auto kDefaultNumSendRecvThreads = 4;
struct TORCH_API ProcessGroupRpcBackendOptions : public RpcBackendOptions {
ProcessGroupRpcBackendOptions(
int num_send_recv_threads,
float rpc_timeout,
std::string init_method)
: RpcBackendOptions(rpc_timeout, init_method),
numSendRecvThreads(num_send_recv_threads) {
TORCH_CHECK(
num_send_recv_threads > 0,
"Cannot create ProcessGroup RPC backend with ",
num_send_recv_threads,
" threads in the thread-pool.");
}
int numSendRecvThreads;
};
// SendWork and RecvWork will be put into a task queue, and later picked up by
// worker threads from the same ThreadPool.
struct TORCH_API SendWork {
SendWork(const WorkerInfo& to, c10::intrusive_ptr<Message> message)
: to_(to), message_(std::move(message)) {}
const WorkerInfo& to_;
c10::intrusive_ptr<Message> message_;
};
// SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is
// to allow us to run serialization/deserialization in the worker threads.
struct TORCH_API RecvWork {
RecvWork(
const WorkerInfo& from,
MessageType type,
int64_t id,
torch::Tensor&& payload)
: from_(from), type_(type), id_(id), payload_(payload) {}
const WorkerInfo& from_;
const MessageType type_;
const int64_t id_;
torch::Tensor payload_;
};
class TORCH_API ProcessGroupAgent : public RpcAgent {
public:
ProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb);
const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
const WorkerInfo& getWorkerInfo(worker_id_t id) const override;
std::vector<WorkerInfo> getWorkerInfos() const override;
void join(bool shutdown = false) override;
void sync() override;
void startImpl() override;
void shutdownImpl() override;
~ProcessGroupAgent() override;
std::unordered_map<std::string, std::string> getMetrics() override;
protected:
// This method wraps the destination information and the message into a
// SendWork object, and puts the SendWork into a queue. Another thread will
// consume SendWork from the queue and send it out.
c10::intrusive_ptr<JitFuture> send(
const WorkerInfo& to,
c10::intrusive_ptr<Message> message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout,
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
override;
// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
// Bypass handleSend() logic and send a message to self rank
virtual void sendToSelf(c10::intrusive_ptr<Message> message);
private:
class MessageCounter {
public:
explicit MessageCounter(int worldSize);
void increment(int dst);
std::vector<int64_t> snapshot();
private:
std::vector<int64_t> counters_;
std::mutex mutex_;
};
// TODO: this class should inherit from a MetricsTracker, and can be extended
// to track num_sends, recvs, average size of messages, etc.
struct AverageMetricsTracker {
std::string key_;
uint64_t currentSum_;
uint64_t currentCount_;
explicit AverageMetricsTracker(
std::string key,
uint64_t currentSum = 0,
uint64_t currentCount = 0);
void addData(uint64_t dataPoint);
double computeAverage();
};
// The FutureInfo struct stores a shared_ptr to the future, as well as
// additional information to manage timeouts and destination information,
// which is needed for termination detection.
struct FutureInfo {
c10::intrusive_ptr<JitFuture> future_;
steady_clock_time_point endTime_;
int dstRank_;
std::chrono::milliseconds timeout_;
FutureInfo(
c10::intrusive_ptr<JitFuture> future,
const steady_clock_time_point& endTime,
int dstRank,
const std::chrono::milliseconds timeout)
: future_(std::move(future)),
endTime_(endTime),
dstRank_(dstRank),
timeout_(timeout) {}
FutureInfo() = delete;
};
// handle a SendWork request. This serializes the payload inside the work
// object, and sends the message to the receiver using the underlying
// ProcessGroup.
void handleSend(const SendWork& work);
// put RecvWork into a queue and notify the worker thread
void enqueueRecv(RecvWork work);
// handle a RecvWork request. Return true if we should increment recvCounts,
// false if not (i.e. if the RPC timed out and we are getting a result after
// the timeout). This ensures that the messages accounted for in
// hasPendingMessage() are tallied properly during a graceful shutdown.
bool handleRecv(RecvWork& work);
// Loop that receives and processes messages
void listenLoopInternal();
// Calls listenLoopInternal and handles errors such as timeouts on the
// process group.
void listenLoop();
// exception_ptr corresponding to an exception raised in listenLoop (if
// there is one), and lock to guard access.
std::exception_ptr listenLoopException_;
std::mutex listenLoopExceptionMutex_;
// poll for timed out RPCs
void pollTimedOutRPCs();
// process timed out futures
const std::vector<FutureInfo> processTimedOutFutures();
// compute the remaining time for an RPC, given its end time.
const std::chrono::milliseconds getRPCRemainingTime(
const std::chrono::milliseconds& rpcEndTime) const;
// a helper function to mark a future in the futures_ map with a message. The
// future is marked with the passed in message, and then removed from the
// futures_ map. It is also removed from the futureTimeouts_ map since these
// maps are kept in sync.
void markFutureWithError(Message& message);
void markFutureWithError(int64_t id, std::string errorMsg);
// Note [Termination Detection]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// RpcAgent implementations must properly detect termination. Otherwise, it
// would result in message loss, RRef leak, or process hang. It is not
// sufficient to just wait for the thread pool to finish processing all tasks
// after all processes hit the join function. There could be nested rpc/remote
// calls, meaning that an empty task queue in the thread pool does not mean
// there will be no tasks added in the future. Moreover, in the listenLoop,
// there is a period of time when the message has been received but not yet
// inserted into the thread pool, which also suggests that the empty task
// queue is not a good indicator for termination.
//
// To detect termination, each ProcessGroupAgent maintains a sent message
// counter and a received message counter. The sent message counter is
// incremented whenever a message is sent, and the received message counter is
// only incremented when a message has been processed. During termination, all
// ProcessGroupAgent instances run an allgather to collect counters from all
// peers, which means that all agents will have a consistent view on the
// message count snapshot. They would only terminate if all sent/received
// message counters match.
bool hasPendingMessage();
int64_t nextId() {
return ++nextId_;
}
c10::intrusive_ptr<::c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, worker_id_t> nameMap_;
std::vector<WorkerInfo> allWorkerInfo_;
// record the number of messages sent to and received from each peer. The recv
// counter is only marked after the message is processed. Join uses allgather
// to collect all counts from all peers, uses these counters to detect global
// termination and only exit when all sent messages are processed.
MessageCounter sendCounts_;
MessageCounter recvCounts_;
std::atomic<int64_t> nextId_;
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
// when using the same tag.
std::vector<std::mutex> sendMutexes_;
std::thread listenerThread_;
// A thread to poll existing futures and check for timed out ones.
std::thread futureTimeoutThread_;
// Lock and pointer to the currently pending recv work, set in listenLoop() and
// interruptible in shutdown().
std::mutex recvWorkMutex_;
c10::intrusive_ptr<c10d::ProcessGroup::Work> recvWork_;
// Map of dst rank to current outstanding sends that we are waiting on. In the
// case of a call to ::shutdown() while we are still waiting on these sends,
// the pending sends contained in this map will be aborted, allowing the
// waiting thread to be unblocked.
std::unordered_map<
worker_id_t,
std::set<c10::intrusive_ptr<c10d::ProcessGroup::Work>>>
currentPendingSends_;
// Lock to serialize access to the above map.
std::mutex pendingSendMutex_;
// A threadPool that processes both SendWork and RecvWork. There are two
// motivations for adding a ThreadPool:
// (1) RPC serialization/deserialization and processing can be expensive,
// hence using multiple threads to speed it up.
// (2) The current RPC API does not support asynchronous UDFs, e.g., UDFs
// cannot yield in the middle of execution to wait for IO and resume once
// the IO is done. This would result in deadlocks when we have nested RPC
// calls.
// NB: Ideally, this should be addressed by supporting asynchronous UDF.
// This is just a temporary solution for (2).
ThreadPool threadPool_;
// Atomic to indicate whether the timeout thread is enabled.
std::atomic<bool> timeoutThreadEnabled_;
// Mapping of request id to FutureInfo struct.
std::unordered_map<int64_t, FutureInfo> futures_;
// A map to keep track of when futures time out. The map is keyed by the time
// (millisecond level precision) the future will expire. This is so that timed
// out futures can be efficiently cleaned up, and we can quickly exit if we
// find a future that has not timed out. The values correspond to an
// unordered_set of future ids that expire at that time. This map must be
// kept in sync with the above futures_ map.
std::map<steady_clock_time_point, std::unordered_set<int64_t>>
futureTimeouts_;
mutable std::mutex futureMutex_;
mutable std::condition_variable futureCV_;
// CV to wake up watchdog thread that watches for timed out futures.
std::condition_variable futureTimeoutCV_;
// Metrics tracked for ProcessGroupAgent.
enum ProcessGroupAgentMetrics {
GIL_WAIT_TIME = 0,
N_METRICS,
};
std::mutex metricsMutex_;
std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
std::atomic<int32_t> clientActiveCalls_{0};
std::atomic<int32_t> serverActiveCalls_{0};
std::atomic<int32_t> serverActiveAsyncCalls_{0};
};
} // namespace rpc
} // namespace distributed
} // namespace torch

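Illustration (not part of this diff): a self-contained sketch of the futureTimeouts_ bookkeeping described above, an ordered map from expiration time to request ids that lets the watchdog pop expired entries and stop at the first one still in the future, roughly what processTimedOutFutures() does. The names TimeoutMap and popTimedOut are hypothetical.

#include <chrono>
#include <cstdint>
#include <iostream>
#include <map>
#include <unordered_set>
#include <vector>

using Clock = std::chrono::steady_clock;
using TimeoutMap = std::map<Clock::time_point, std::unordered_set<int64_t>>;

// Collect and erase every request id whose end time has passed. Because the
// map is ordered by end time, the scan stops at the first entry that has not
// expired yet.
std::vector<int64_t> popTimedOut(TimeoutMap& futureTimeouts, Clock::time_point now) {
  std::vector<int64_t> timedOut;
  for (auto it = futureTimeouts.begin(); it != futureTimeouts.end();) {
    if (now < it->first) {
      break; // all remaining entries expire even later
    }
    timedOut.insert(timedOut.end(), it->second.begin(), it->second.end());
    it = futureTimeouts.erase(it);
  }
  return timedOut;
}

int main() {
  TimeoutMap futureTimeouts;
  const auto now = Clock::now();
  futureTimeouts[now - std::chrono::seconds(1)].insert(1);  // already expired
  futureTimeouts[now + std::chrono::seconds(30)].insert(2); // still pending
  for (int64_t id : popTimedOut(futureTimeouts, now)) {
    std::cout << "request " << id << " timed out\n"; // prints only request 1
  }
}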

@@ -1,152 +0,0 @@
#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
namespace distributed {
namespace rpc {
std::string fromVec(const std::vector<char>& vec) {
return std::string(vec.begin(), vec.end());
}
FaultyProcessGroupAgent::FaultyProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<::c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb,
const std::vector<std::string>& messagesToFail,
const std::unordered_map<std::string, float>& messageTypesToDelay,
int failNumSends)
: ProcessGroupAgent(
store,
std::move(workerName),
std::move(pg),
numSendRecvThreads,
rpcTimeout,
std::move(cb)),
failNumSends_(failNumSends),
messageTypesToFail_(parseMessagesToFailInput(messagesToFail)),
messageTypesToDelay_(parseMessagesToDelay(messageTypesToDelay)) {}
std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
const std::vector<std::string>& messagesToFail) const {
// Since we can only pass strings corresponding to the Message Types from the
// python tests, we must parse the list of strings and resolve the actual
// types. We will then check this list of types in the send function to
// determine whether we should fail or not.
std::vector<MessageType> messageTypesToFail;
messageTypesToFail.reserve(messagesToFail.size());
for (const auto& msgString : messagesToFail) {
messageTypesToFail.push_back(messageStringToType(msgString));
}
return messageTypesToFail;
}
std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
parseMessagesToDelay(const std::unordered_map<std::string, float>&
messageTypesToDelay) const {
std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
for (const auto& messagePair : messageTypesToDelay) {
float delay = messagePair.second;
TORCH_CHECK(
delay >= 0,
"Delays passed to FaultyProcessGroupAgent must be non-negative.")
delayMessages.insert({messageStringToType(messagePair.first), delay});
}
return delayMessages;
}
c10::intrusive_ptr<JitFuture> FaultyProcessGroupAgent::send(
const WorkerInfo& to,
c10::intrusive_ptr<Message> message,
const float rpcTimeoutSeconds,
const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
// We only fail control messages that have been specified by the test case.
// For all other messages, we just send them without any failures.
if (!shouldFailMessage(message->type())) {
return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
}
// This send function consults the failMessageCountMap_ to determine whether
// we must fail the next send. If the send must be failed, we set an error
// on the returned future immediately and increment the counter in the map,
// otherwise we just call the ProcessGroupAgent send.
const auto key = fromVec(message->payload());
std::unique_lock<std::mutex> lock(failMapMutex_);
auto it = failMessageCountMap_.find(key);
if (it == failMessageCountMap_.end()) {
failMessageCountMap_[key] = 0;
}
if (failMessageCountMap_[key] < failNumSends_) {
failMessageCountMap_[key]++;
lock.unlock();
auto jitFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError(
c10::str("Send attempt failed intentionally for ", key),
RPCErrorType::INTENTIONAL_FAILURE))));
return jitFuture;
} else {
lock.unlock();
return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
}
}
void FaultyProcessGroupAgent::enqueueSend(SendWork work) {
float msgDelay = getDelayForMessage(work.message_->type());
if (msgDelay != 0) {
// Sleep for the specified delay for the message.
std::this_thread::sleep_for(std::chrono::milliseconds(
static_cast<int>(msgDelay * kSecToMsConversion)));
}
ProcessGroupAgent::enqueueSend(std::move(work));
}
void FaultyProcessGroupAgent::sendToSelf(c10::intrusive_ptr<Message> message) {
float msgDelay = getDelayForMessage(message->type());
if (msgDelay != 0) {
// Sleep for the specified delay for the message.
std::this_thread::sleep_for(std::chrono::milliseconds(
static_cast<int>(msgDelay * kSecToMsConversion)));
}
ProcessGroupAgent::sendToSelf(std::move(message));
}
bool FaultyProcessGroupAgent::shouldFailMessage(MessageType type) const {
// Return true if the input message type is in the messageTypesToFail_ list
return (
std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
messageTypesToFail_.end());
}
float FaultyProcessGroupAgent::getDelayForMessage(MessageType type) const {
const auto& it = messageTypesToDelay_.find(type);
return it == messageTypesToDelay_.end() ? 0 : it->second;
}
MessageType FaultyProcessGroupAgent::messageStringToType(
const std::string& messageString) const {
// Lazily constructed map that returns string to message type mapping
static std::unordered_map<std::string, MessageType> msgMap = {
{"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
{"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
{"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
{"CLEANUP_AUTOGRAD_CONTEXT_REQ",
MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
{"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
{"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
{"PYTHON_CALL", MessageType::PYTHON_CALL},
{"SCRIPT_CALL", MessageType::SCRIPT_CALL},
{"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
{"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
const auto& it = msgMap.find(messageString);
TORCH_CHECK(
it != msgMap.end(),
"No mapping to rpc::MessageType exists for ",
messageString);
return it->second;
}
} // namespace rpc
} // namespace distributed
} // namespace torch

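Illustration (not part of this diff): a minimal model of the fail-the-first-N-sends bookkeeping in FaultyProcessGroupAgent::send(), using a per-key counter guarded by a mutex. FailNTimes and shouldFail() are hypothetical names; the real agent keys the map by the serialized message payload and sets an error on the returned future rather than printing.

#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

// The first `failNumSends` attempts for a given key are failed intentionally,
// after which sends for that key are allowed through.
class FailNTimes {
 public:
  explicit FailNTimes(int failNumSends) : failNumSends_(failNumSends) {}

  // Returns true if this attempt should be failed intentionally.
  bool shouldFail(const std::string& key) {
    std::lock_guard<std::mutex> lock(mutex_);
    int& count = failCount_[key];
    if (count < failNumSends_) {
      ++count;
      return true;
    }
    return false;
  }

 private:
  const int failNumSends_;
  std::unordered_map<std::string, int> failCount_;
  std::mutex mutex_;
};

int main() {
  FailNTimes faults(/*failNumSends=*/2);
  for (int attempt = 0; attempt < 4; ++attempt) {
    std::cout << "attempt " << attempt << ": "
              << (faults.shouldFail("SCRIPT_CALL") ? "fail" : "send") << "\n";
  }
}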

@@ -1,99 +0,0 @@
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/process_group_agent.h>
namespace torch {
namespace distributed {
namespace rpc {
struct TORCH_API FaultyProcessGroupRpcBackendOptions
: public ProcessGroupRpcBackendOptions {
FaultyProcessGroupRpcBackendOptions(
int num_send_recv_threads,
float rpc_timeout,
std::string init_method,
std::vector<std::string> messages_to_fail,
std::unordered_map<std::string, float> messages_to_delay,
int num_fail_sends = 0)
: ProcessGroupRpcBackendOptions(
num_send_recv_threads,
rpc_timeout,
std::move(init_method)),
messagesToFail(std::move(messages_to_fail)),
messagesToDelay(std::move(messages_to_delay)),
numFailSends(num_fail_sends) {
TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
}
std::vector<std::string> messagesToFail;
std::unordered_map<std::string, float> messagesToDelay;
int numFailSends;
};
class TORCH_API FaultyProcessGroupAgent : public ProcessGroupAgent {
public:
FaultyProcessGroupAgent(
const c10::intrusive_ptr<::c10d::Store>& store,
std::string workerName,
c10::intrusive_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb,
const std::vector<std::string>& messagesToFail,
const std::unordered_map<std::string, float>& messageTypesToDelay,
int failNumSends = 0);
// Faulty send function for this class.
c10::intrusive_ptr<JitFuture> send(
const WorkerInfo& to,
c10::intrusive_ptr<Message> message,
const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
override;
protected:
// This function checks the messageTypesToFail_ to determine whether to use
// the faulty send or not.
virtual bool shouldFailMessage(MessageType type) const;
private:
// Overrides ProcessGroupAgent's enqueueSend to inject delays.
void enqueueSend(SendWork work) override;
// Overrides ProcessGroupAgent's sendToSelf to inject delays.
void sendToSelf(c10::intrusive_ptr<Message> message) override;
// This function parses the list of strings passed in by the python tests and
// resolves the Message Types that must use the faulty send.
std::vector<MessageType> parseMessagesToFailInput(
const std::vector<std::string>& messagesToFail) const;
// Returns amount of time in seconds to delay sending of the given message
// type.
float getDelayForMessage(MessageType type) const;
// Parse message types that we should inject arbitrary delays for.
std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
const std::unordered_map<std::string, float>& messageTypesToDelay) const;
// Number of sends to intentionally fail before allowing one to succeed.
const int failNumSends_;
// Vector of the MessageTypes that we must use the faulty send for. This is
// parsed based on a list of strings passed in by the python tests.
const std::vector<MessageType> messageTypesToFail_;
// Mapping of message types to the amount of time (in seconds) by which sends
// of that type should be delayed.
std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;
// Map to track the number of sends we've failed for each RPC.
std::unordered_map<std::string, int> failMessageCountMap_;
// Mutex to guard failMessageCountMap_
std::mutex failMapMutex_;
MessageType messageStringToType(const std::string& messageString) const;
};
} // namespace rpc
} // namespace distributed
} // namespace torch
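
Illustration (not part of this diff): a self-contained sketch of the name-to-type parsing with non-negative validation that parseMessagesToDelay() performs. MsgType and parseDelays are hypothetical stand-ins for MessageType and the real member function; unknown names throw here where the real code uses TORCH_CHECK in messageStringToType().

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// User-facing names are mapped to message types, and negative delays are
// rejected, as the faulty agent's constructor does.
enum class MsgType { SCRIPT_CALL, PYTHON_CALL };

std::unordered_map<MsgType, float> parseDelays(
    const std::unordered_map<std::string, float>& byName) {
  static const std::unordered_map<std::string, MsgType> nameToType = {
      {"SCRIPT_CALL", MsgType::SCRIPT_CALL},
      {"PYTHON_CALL", MsgType::PYTHON_CALL},
  };
  std::unordered_map<MsgType, float> delays;
  for (const auto& entry : byName) {
    if (entry.second < 0) {
      throw std::invalid_argument("delays must be non-negative");
    }
    delays.emplace(nameToType.at(entry.first), entry.second);
  }
  return delays;
}

int main() {
  const auto delays = parseDelays({{"SCRIPT_CALL", 1.5f}});
  std::cout << "SCRIPT_CALL delay: " << delays.at(MsgType::SCRIPT_CALL) << "s\n";
}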