mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Separate monitoring thread into a class in PGNCCL (#153977)
This is the start of a series of efforts to consolidating auxiliary threads in PGNCCL, aka watchdog and heartbeat_monitoring threads. Right now we launch these two threads per PG instances, i.e., if users create hundred or thousand instances of PG or subPGs, we will end up with that twice many side threads which is not efficient. We have a RFC to consolidate them (https://github.com/pytorch/pytorch/issues/146956). Right now both threads are assigned with so many functionalities so it is hard to do the consolidations in one shot, we will try to split it into at least two steps (PRs) to make it easier to test and review. We did our first attemp in https://github.com/pytorch/pytorch/pull/153668 but we also want to try to see if we can make monitoring thread a class. This PR is doing the first step to make monitoring thread a class. The next step to also extract watchdog to be a separate class so that we know its dependency. What we did in this PR: 1. Move all related variables and methods into a class named `HeartbeatMonitor`. 2. Correct some errors in the original logics inside monitoring thread loop. 3. Move the error propagation check to watchdog thread which is more relevant. This is totally fine since we rolled out EventCache out fully so watchdog hang is rare now. Today there are two major functions inside heartbeat monitoring thread today: 1. Check the heartbeat of watchdog thread every 8 minutes. If no heartbeat detected and we are sure monitoring thread has not been stopped, we will kill the program by SIG_ABORT. 2. We check TCPStore every 30 sec to see if any watchdog timeout happens on other ranks, if so we will initiate a dump signal on the current rank as well. (We do this only in the default PG) Differential Revision: [D75799278](https://our.internmc.facebook.com/intern/diff/D75799278) Pull Request resolved: https://github.com/pytorch/pytorch/pull/153977 Approved by: https://github.com/kwen2501, https://github.com/d4l3k
This commit is contained in:
@ -179,7 +179,12 @@ class ProcessGroupNCCLNoHeartbeatCaught
|
||||
int rank,
|
||||
int size,
|
||||
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> opts)
|
||||
: ProcessGroupNCCLTimedOutErrors(store, rank, size, std::move(opts)) {}
|
||||
: ProcessGroupNCCLTimedOutErrors(store, rank, size, std::move(opts)) {
|
||||
// Override the heartbeat monitor function to make sure that we capture
|
||||
// the exception in the monitor thread because we cannot try-catch it in
|
||||
// the main thread and we set a flag for the main thread to check.
|
||||
heartbeatMonitor_ = std::make_unique<TestHeartbeatMonitor>(this);
|
||||
}
|
||||
|
||||
std::mutex& getWatchdogMutex() {
|
||||
return workMetaListMutex_;
|
||||
@ -195,18 +200,22 @@ class ProcessGroupNCCLNoHeartbeatCaught
|
||||
asyncDebugDump.wait();
|
||||
}
|
||||
|
||||
protected:
|
||||
// Override the heartbeat monitor function to make sure that we capture
|
||||
// the exception in the monitor thread because we cannot try-catch it in
|
||||
// the main thread and we set a flag for the main thread to check.
|
||||
void heartbeatMonitor() override {
|
||||
try {
|
||||
c10d::ProcessGroupNCCL::heartbeatMonitor();
|
||||
} catch (std::runtime_error& e) {
|
||||
hasMonitorThreadCaughtError_ = true;
|
||||
}
|
||||
}
|
||||
class TestHeartbeatMonitor : public c10d::ProcessGroupNCCL::HeartbeatMonitor {
|
||||
public:
|
||||
using HeartbeatMonitor::HeartbeatMonitor;
|
||||
|
||||
void runLoop() override {
|
||||
try {
|
||||
c10d::ProcessGroupNCCL::HeartbeatMonitor::runLoop();
|
||||
} catch (std::runtime_error& e) {
|
||||
// Safe cast because we know it's a ProcessGroupNCCLNoHeartbeatCaught
|
||||
auto* pg = static_cast<ProcessGroupNCCLNoHeartbeatCaught*>(pg_);
|
||||
pg->hasMonitorThreadCaughtError_ = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
// It's really hard to unit test std::abort. So we override it instead.
|
||||
// Commented this override, we do see process aborted with core dump without
|
||||
// this override.
|
||||
|
@ -934,7 +934,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
store_(std::move(store)),
|
||||
options_(std::move(options)),
|
||||
terminateProcessGroup_(false),
|
||||
terminateHeartbeatMonitorThread_(false),
|
||||
local_id_(process_group_id++),
|
||||
intraNodeComm_(initIntraNodeComm()) {
|
||||
TORCH_CHECK_WITH(
|
||||
@ -956,24 +955,12 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
|
||||
// TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
|
||||
// or change its name to reflect that dump happens on exception including
|
||||
// both timeout and other errors.
|
||||
dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false);
|
||||
// logging C++ stack isn't safe. Introduce a variable to control it.
|
||||
logCppStackOnUncleanShutdown_ =
|
||||
getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
|
||||
enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
|
||||
heartbeat_ = 1ULL;
|
||||
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
|
||||
cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true));
|
||||
heartbeatTimeoutInSec_ =
|
||||
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
|
||||
waitTimeoutDumpInMilSec_ =
|
||||
getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
|
||||
coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
|
||||
traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000);
|
||||
enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
|
||||
// store_ usually is wrapped with PrefixStore and the prefix is different
|
||||
@ -1007,6 +994,10 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the heartbeat monitor instance. This has to be done before
|
||||
// the watchdog thread is launched to avoid the error.
|
||||
heartbeatMonitor_ = std::make_unique<HeartbeatMonitor>(this);
|
||||
|
||||
#ifdef ENABLE_NCCL_ERROR_CHECKING
|
||||
// in blockingWait mode, we don't need to enable the watchdog thread to check
|
||||
// the timeout or nccl error because the main thread would throw an exception
|
||||
@ -1033,10 +1024,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: "
|
||||
<< "NCCL version: " << ncclVersion
|
||||
<< ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
|
||||
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_
|
||||
<< ", TORCH_NCCL_PROPAGATE_ERROR: " << propagatePgError_
|
||||
<< ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
|
||||
<< waitTimeoutDumpInMilSec_
|
||||
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
|
||||
<< ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
|
||||
<< ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
|
||||
@ -1045,15 +1033,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
<< ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
|
||||
<< shouldAllCommunicatorsRegisterAllTensors()
|
||||
#endif // NCCL_HAS_COMM_REGISTER
|
||||
<< ", TORCH_NCCL_ENABLE_MONITORING: "
|
||||
<< monitorThreadEnabled_.load()
|
||||
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
|
||||
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << traceBufferSize_
|
||||
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
|
||||
<< ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
|
||||
<< ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_
|
||||
<< ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: "
|
||||
<< logCppStackOnUncleanShutdown_;
|
||||
<< ", TORCH_NCCL_CUDA_EVENT_CACHE: " << cudaEventCacheEnabled_;
|
||||
|
||||
getGlobalRankStartAndStride(
|
||||
options_->global_ranks_in_group,
|
||||
@ -1458,8 +1440,7 @@ void ProcessGroupNCCL::abort() {
|
||||
|
||||
// We need to wait for abort to finish before we can safely shut down
|
||||
// heartbeat monitoring thread.
|
||||
terminateHeartbeatMonitorThread_.store(true);
|
||||
monitorWakeUpCV_.notify_one();
|
||||
heartbeatMonitor_->stop();
|
||||
}
|
||||
|
||||
// Difference between `abort()` and `shutdown()`:
|
||||
@ -1508,8 +1489,7 @@ void ProcessGroupNCCL::shutdown() {
|
||||
}
|
||||
// Watchdog thread exiting, retire heartbeat monitoring thread now to avoid
|
||||
// false alarm
|
||||
terminateHeartbeatMonitorThread_.store(true);
|
||||
monitorWakeUpCV_.notify_one();
|
||||
heartbeatMonitor_->stop();
|
||||
// Destroy the communicator, reclaim resources
|
||||
LOG(INFO) << logPrefix() << "Watchdog joined, destroying NCCL communicators.";
|
||||
{
|
||||
@ -1564,19 +1544,14 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
|
||||
terminateProcessGroup_.store(true);
|
||||
workMetaListCV_.notify_one();
|
||||
// Tell heartbeat thread:
|
||||
terminateHeartbeatMonitorThread_.store(true);
|
||||
monitorWakeUpCV_.notify_one();
|
||||
heartbeatMonitor_->stop();
|
||||
|
||||
// Wait for all threads to finish before returning
|
||||
if (ncclCommWatchdogThread_.joinable()) {
|
||||
ncclCommWatchdogThread_.join();
|
||||
LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
|
||||
}
|
||||
if (ncclHeartbeatMonitorThread_.joinable()) {
|
||||
ncclHeartbeatMonitorThread_.join();
|
||||
LOG(INFO) << logPrefix()
|
||||
<< "ProcessGroupNCCL heart beat monitor thread joined.";
|
||||
}
|
||||
heartbeatMonitor_->join();
|
||||
if (onCompletionHookThread_.joinable()) {
|
||||
onCompletionHookThread_.join();
|
||||
LOG(INFO) << logPrefix()
|
||||
@ -1625,17 +1600,21 @@ static long computeDeltaMS(
|
||||
.count();
|
||||
}
|
||||
|
||||
std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
|
||||
void ProcessGroupNCCL::setEnableNanCheck(bool enableNanCheck) {
|
||||
enableNanCheck_ = enableNanCheck;
|
||||
}
|
||||
|
||||
std::string ProcessGroupNCCL::HeartbeatMonitor::getNCCLWatchdogTimeoutErrorMsg(
|
||||
const std::string& extraMsg) {
|
||||
return c10::str(
|
||||
logPrefix(),
|
||||
pg_->logPrefix(),
|
||||
"Received a dump signal due to a collective timeout from ",
|
||||
extraMsg,
|
||||
" and we will try our best to dump the debug info. ",
|
||||
"Last enqueued NCCL work: ",
|
||||
pgStatus_->lastEnqueuedSeq,
|
||||
pg_->pgStatus_->lastEnqueuedSeq,
|
||||
", last completed NCCL work: ",
|
||||
pgStatus_->lastCompletedSeq,
|
||||
pg_->pgStatus_->lastCompletedSeq,
|
||||
".",
|
||||
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
|
||||
"sizes used across ranks, the order of collectives is not same for all ranks ",
|
||||
@ -1644,37 +1623,91 @@ std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
|
||||
"bugs in the communications library (e.g. NCCL), etc. ");
|
||||
}
|
||||
|
||||
std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutExitMsg(
|
||||
std::string ProcessGroupNCCL::HeartbeatMonitor::getNCCLWatchdogTimeoutExitMsg(
|
||||
const std::string& exitReason) {
|
||||
return c10::str(
|
||||
logPrefix(),
|
||||
pg_->logPrefix(),
|
||||
"Terminating the process after attempting to dump debug info, due to ",
|
||||
exitReason,
|
||||
".");
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::setEnableNanCheck(bool enableNanCheck) {
|
||||
enableNanCheck_ = enableNanCheck;
|
||||
void ProcessGroupNCCL::HeartbeatMonitor::setLastWorkListUpdateTime(
|
||||
std::chrono::time_point<std::chrono::steady_clock> time) {
|
||||
// We intentially let the race condition to happen but this is ok
|
||||
// as long as we update the time, we know we are making progress.
|
||||
lastWorkListUpdateTime_ = time;
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
ProcessGroupNCCL::HeartbeatMonitor::HeartbeatMonitor(ProcessGroupNCCL* pg) {
|
||||
pg_ = pg;
|
||||
heartbeatTimeoutInSec_ =
|
||||
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
|
||||
waitTimeoutDumpInMilSec_ =
|
||||
getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
|
||||
coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
|
||||
// TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
|
||||
// or change its name to reflect that dump happens on exception including
|
||||
// both timeout and other errors.
|
||||
dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true);
|
||||
// logging C++ stack isn't safe. Gate it with an ENV.
|
||||
logCppStackOnUncleanShutdown_ =
|
||||
getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
|
||||
watchdogHeartbeatMonitorEnabled_ =
|
||||
getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true);
|
||||
|
||||
// print out ENV settings for the heartbeat monitor thread.
|
||||
LOG(INFO)
|
||||
<< pg_->logPrefix() << "HeartbeatMonitor environments: "
|
||||
<< "TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): "
|
||||
<< watchdogHeartbeatMonitorEnabled_
|
||||
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeoutOrEx_
|
||||
<< ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: " << waitTimeoutDumpInMilSec_
|
||||
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
|
||||
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
|
||||
<< ", TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: "
|
||||
<< logCppStackOnUncleanShutdown_;
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::HeartbeatMonitor::stop() {
|
||||
terminateHeartbeatMonitorThread_.store(true);
|
||||
monitorWakeUpCV_.notify_one();
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::HeartbeatMonitor::start() {
|
||||
TORCH_CHECK(
|
||||
!ncclHeartbeatMonitorThread_.joinable(),
|
||||
"HeartbeatMonitor thread already started");
|
||||
ncclHeartbeatMonitorThread_ =
|
||||
std::thread(&ProcessGroupNCCL::HeartbeatMonitor::runLoop, this);
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::HeartbeatMonitor::join() {
|
||||
if (ncclHeartbeatMonitorThread_.joinable()) {
|
||||
ncclHeartbeatMonitorThread_.join();
|
||||
LOG(INFO) << pg_->logPrefix()
|
||||
<< "ProcessGroupNCCL heart beat monitor thread joined.";
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::HeartbeatMonitor::runLoop() {
|
||||
c10::setThreadName("pt_nccl_heartbt");
|
||||
|
||||
uint64_t heartBeatCounter = 0ULL;
|
||||
std::string errorMsg;
|
||||
std::string exitReason;
|
||||
bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0);
|
||||
int monitorPollInterval = checkDumpSignal || propagatePgError_
|
||||
? coordCheckIntervalMilSec_
|
||||
: heartbeatTimeoutInSec_ * 1000;
|
||||
bool checkDumpSignal = (dumpOnTimeoutOrEx_ && pg_->getUid() == 0);
|
||||
int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
|
||||
: heartbeatTimeoutInSec_ * 1000;
|
||||
auto lastTimePollStore = std::chrono::steady_clock::now();
|
||||
auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
|
||||
std::optional<DumpPipe> dumpPipe = std::nullopt;
|
||||
if (local_id_ == 0) {
|
||||
|
||||
if (pg_->getUid() == 0) {
|
||||
// DumpPipe is one per-trainer process, and its convenient to name them
|
||||
// after 'global' ranks in the system, So we assume processgroup (uid)==0 is
|
||||
// the global PG and has globally unique rank ids across trainers.
|
||||
dumpPipe.emplace(rank_);
|
||||
dumpPipe.emplace(pg_->globalRank());
|
||||
}
|
||||
while (true) {
|
||||
// This won't have any lock since this lock is only used here.
|
||||
@ -1691,11 +1724,6 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
}
|
||||
auto currentTime = std::chrono::steady_clock::now();
|
||||
|
||||
if (propagatePgError_) {
|
||||
// Check and set remote error if it has not been set before
|
||||
checkAndSetRemoteError();
|
||||
}
|
||||
|
||||
// We put extra functionality in the thread for the default PG (aka,
|
||||
// local_id_=0) because the signal is same across different PGs. We only
|
||||
// need to run once per process to avoid duplicate things performed in too
|
||||
@ -1725,7 +1753,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
lastTimePollStore = currentTime;
|
||||
auto handleError = [&](const std::string& errorMessage) {
|
||||
LOG(WARNING)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "Failed to check the \"should dump\" flag on TCPStore, "
|
||||
<< "(maybe TCPStore server has shut down too early), with error: "
|
||||
<< errorMessage;
|
||||
@ -1737,7 +1765,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
bool checkExceptionDump = false;
|
||||
try {
|
||||
checkExceptionDump =
|
||||
globalStore_->check({std::string(kStoreDumpKey)});
|
||||
pg_->globalStore()->check({std::string(kStoreDumpKey)});
|
||||
} catch (const c10::DistNetworkError& e) {
|
||||
handleError(e.msg());
|
||||
} catch (const std::exception& e) {
|
||||
@ -1748,19 +1776,19 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
int timeOutRank = -1;
|
||||
if (!shouldDump_.load()) {
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "Observed flight recorder dump signal from another rank via TCPStore.";
|
||||
}
|
||||
shouldDump_.store(true);
|
||||
try {
|
||||
auto vec = globalStore_->get(std::string(kStoreDumpKey));
|
||||
auto vec = pg_->globalStore()->get(std::string(kStoreDumpKey));
|
||||
TORCH_CHECK_WITH(
|
||||
DistBackendError,
|
||||
vec.size() == sizeof(int),
|
||||
"Invalid size for the timeout rank ID");
|
||||
std::memcpy(&timeOutRank, vec.data(), vec.size());
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR) << logPrefix()
|
||||
LOG(ERROR) << pg_->logPrefix()
|
||||
<< "Failed to get timeout rank ID from TCPStore."
|
||||
<< e.what();
|
||||
}
|
||||
@ -1776,14 +1804,14 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
heartbeatTimeoutInSec_ * 1000l) {
|
||||
// Check the heart beat of watchdog thread.
|
||||
lastTimeHeartBeatCheck = currentTime;
|
||||
auto heartbeat = heartbeat_.load();
|
||||
auto heartbeat = pg_->getWatchdogHeartbt();
|
||||
if (heartbeat != heartBeatCounter) {
|
||||
heartBeatCounter = heartbeat;
|
||||
} else {
|
||||
shouldDump_.store(true);
|
||||
// Watchdog heartbeat timeout.
|
||||
errorMsg = c10::str(
|
||||
logPrefix(),
|
||||
pg_->logPrefix(),
|
||||
"ProcessGroupNCCL's watchdog got stuck for ",
|
||||
heartbeatTimeoutInSec_,
|
||||
" seconds without making progress in monitoring enqueued collectives. ",
|
||||
@ -1804,8 +1832,9 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// recorder and dump. After dump, the training should continue.
|
||||
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
|
||||
// best effort dump, not waiting for the dump here
|
||||
std::future<bool> fut = std::async(
|
||||
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
|
||||
std::future<bool> fut = std::async(std::launch::async, [this]() {
|
||||
return this->pg_->dumpDebuggingInfo();
|
||||
});
|
||||
}
|
||||
}
|
||||
LOG(ERROR) << errorMsg;
|
||||
@ -1824,19 +1853,19 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// local disk)
|
||||
bool dumpStackTrace = true;
|
||||
::c10d::C10dLoggingData debugLog;
|
||||
debugLog.integers["pg_id"] = static_cast<int64_t>(local_id_);
|
||||
debugLog.integers["rank"] = rank_;
|
||||
debugLog.integers["global_rank"] = globalRank();
|
||||
debugLog.integers["world_size"] = getSize();
|
||||
debugLog.integers["pg_id"] = static_cast<int64_t>(pg_->getUid());
|
||||
debugLog.integers["rank"] = pg_->getRank();
|
||||
debugLog.integers["global_rank"] = pg_->globalRank();
|
||||
debugLog.integers["world_size"] = pg_->getSize();
|
||||
debugLog.strings["flight_recorder_version"] = c10d::version_val_str;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
std::future<bool> asyncDebugDump =
|
||||
std::async(std::launch::async, [this, dumpStackTrace]() {
|
||||
return this->dumpDebuggingInfo(dumpStackTrace);
|
||||
return this->pg_->dumpDebuggingInfo(dumpStackTrace);
|
||||
});
|
||||
|
||||
// wait for the dump until timeout - log data
|
||||
auto complete = waitForFutureOrTimeout(
|
||||
auto complete = pg_->waitForFutureOrTimeout(
|
||||
asyncDebugDump,
|
||||
std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
|
||||
"Flight recorder dump in heartbeatMonitor",
|
||||
@ -1845,7 +1874,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
|
||||
if (complete) {
|
||||
LOG(INFO)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "Finished flight recorder successfully. Output can be analyzed using the fr_trace script.";
|
||||
if (i > 0) {
|
||||
debugLog.strings["exception_msg"] = "Dump with stack trace failed.";
|
||||
@ -1873,39 +1902,42 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
futStatus != std::future_status::deferred,
|
||||
"Expected the future to have been launched eagerly.");
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
|
||||
}
|
||||
} else {
|
||||
VLOG(2)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "GIL checker was not registered, perhaps this is a no-python build?";
|
||||
}
|
||||
|
||||
// Dump the c++ stacktraces.
|
||||
auto& cpp_dumper = get_cpp_trace_dumper();
|
||||
if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) {
|
||||
LOG(INFO) << logPrefix() << "Dumping c++ stacktraces:";
|
||||
cpp_dumper.value()(
|
||||
[&](const std::string& line) { LOG(INFO) << logPrefix() << line; });
|
||||
LOG(INFO) << logPrefix() << "Finished c++ stacktraces dump.";
|
||||
LOG(INFO) << pg_->logPrefix() << "Dumping c++ stacktraces:";
|
||||
cpp_dumper.value()([&](const std::string& line) {
|
||||
LOG(INFO) << pg_->logPrefix() << line;
|
||||
});
|
||||
LOG(INFO) << pg_->logPrefix() << "Finished c++ stacktraces dump.";
|
||||
}
|
||||
|
||||
// There are two possible cases for the watchdog thread exit:
|
||||
// Case one: desync report runs quickly, and it follows the step:
|
||||
// collective timeout -> desync -> exception handling -> destructors
|
||||
// -> set terminateHeartbeatMonitorThread_ -> notify monitorWakeUpCV_.
|
||||
// So the code either early returns above or will skip the sleep below.
|
||||
// Case two: desync might be slow or get stuck. Or we get stuck in
|
||||
// destructors, we will sleep for some time before calling std::abort() to
|
||||
// kill the whole process.
|
||||
if ((terminateProcessGroup_.load() || desyncDebug_ || shouldDump_.load()) &&
|
||||
// collective timeout -> desync -> exception handling -> throwing exception.
|
||||
// The program will exit because of exception thrown and the code below will
|
||||
// not be run.
|
||||
//
|
||||
// Case two: desync might be slow or get stuck and we need to wait
|
||||
// extra time to avoid we kill the program too early.
|
||||
//
|
||||
// Or we get stuck in destructors, we will sleep for some time before calling
|
||||
// std::abort() to kill the whole process.
|
||||
if ((pg_->terminateProcessGroup_.load() || shouldDump_.load()) &&
|
||||
!terminateHeartbeatMonitorThread_.load()) {
|
||||
// Leave another two mins for desync report generation or process group
|
||||
// destroy.
|
||||
std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
|
||||
LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
|
||||
<< " waiting for desync report or process group destroy.";
|
||||
LOG(INFO)
|
||||
<< pg_->logPrefix() << "slept for " << heartbeatTimeoutInSec_
|
||||
<< " because we want to wait longer to verify there is indeed a watchdog hang.";
|
||||
}
|
||||
|
||||
// At this point, we either already sleep for another `heartbeatTimeoutInSec_`
|
||||
@ -1917,20 +1949,19 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// We already log completion inside the thread, so it may not be necessary to
|
||||
// check the return value here. We mainly use a future so we can exit early
|
||||
// if done.
|
||||
|
||||
if (!terminateHeartbeatMonitorThread_.load()) {
|
||||
// Create a error message reported from MonitorThread, so
|
||||
// we throw exception and make the whole process to be killed.
|
||||
// TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
|
||||
// url here.
|
||||
if (monitorThreadEnabled_.load()) {
|
||||
terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
|
||||
if (watchdogHeartbeatMonitorEnabled_) {
|
||||
pg_->terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
|
||||
} else {
|
||||
// Ideally we want to merge this one with the above one, but we are going
|
||||
// to remove the kill switch for monitor thread soon, so we keep this one
|
||||
// for now.
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< pg_->logPrefix()
|
||||
<< "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process"
|
||||
<< "after attempting to dump debug info, due to " << exitReason
|
||||
<< ".";
|
||||
@ -1943,8 +1974,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
|
||||
|
||||
try {
|
||||
VLOG(2) << logPrefix() << "Process group watchdog thread started!";
|
||||
ncclHeartbeatMonitorThread_ =
|
||||
std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
|
||||
heartbeatMonitor_->start();
|
||||
watchdogHandler();
|
||||
VLOG(2) << logPrefix()
|
||||
<< "Process group watchdog thread terminated normally";
|
||||
@ -2104,6 +2134,10 @@ const int& ProcessGroupNCCL::globalRank() const {
|
||||
return globalRank;
|
||||
}
|
||||
|
||||
const c10::intrusive_ptr<Store>& ProcessGroupNCCL::globalStore() const {
|
||||
return globalStore_;
|
||||
}
|
||||
|
||||
const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
|
||||
if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
|
||||
static std::vector<uint64_t> globalRanks(size_);
|
||||
@ -2237,7 +2271,8 @@ static int getRootIndex(const int rank, const int nRanks, const int nIds) {
|
||||
|
||||
void ProcessGroupNCCL::watchdogHandler() {
|
||||
bool done = false;
|
||||
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
|
||||
heartbeatMonitor_->setLastWorkListUpdateTime(
|
||||
std::chrono::steady_clock::now());
|
||||
auto lastStatusUpdateTime = std::chrono::steady_clock::now();
|
||||
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
|
||||
|
||||
@ -2301,6 +2336,11 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
lastStatusUpdateTime = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
if (propagatePgError_) {
|
||||
// Check and set remote error if it has not been set before
|
||||
checkAndSetRemoteError();
|
||||
}
|
||||
|
||||
for (auto it = workMetaList_.begin(); it != workMetaList_.end();
|
||||
/* no increment */) {
|
||||
auto& work = *it;
|
||||
@ -2361,16 +2401,14 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
// try to notify other ranks via global TCPStore to dump the flight
|
||||
// recorder when a collective timeout or exception happens. Flight
|
||||
// recorder behavior is independent of desync Debug.
|
||||
if (dumpOnTimeoutOrEx_) {
|
||||
broadcastDumpSignal();
|
||||
// Give time for dumping before throwing exception for all ranks.
|
||||
// It is hard to presume or control what the pattern of watchdog might
|
||||
// look like, so it is better to let all ranks universally sleep for a
|
||||
// short period of time, in this case, 60 seconds, which is also the
|
||||
// maximum time we leave for FR dump.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(waitTimeoutDumpInMilSec_ * 4));
|
||||
}
|
||||
broadcastDumpSignal();
|
||||
// Give time for dumping before throwing exception for all ranks.
|
||||
// It is hard to presume or control what the pattern of watchdog might
|
||||
// look like, so it is better to let all ranks universally sleep for a
|
||||
// short period of time, in this case, 60 seconds, which is also the
|
||||
// maximum time we leave for FR dump.
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::milliseconds(waitTimeoutDumpInMilSec_ * 4));
|
||||
|
||||
if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
|
||||
// Abort work and corresponding communicators
|
||||
@ -2452,7 +2490,8 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
completedWorkListCV_.notify_one();
|
||||
} else {
|
||||
it = workMetaList_.erase(it);
|
||||
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
|
||||
heartbeatMonitor_->setLastWorkListUpdateTime(
|
||||
std::chrono::steady_clock::now());
|
||||
}
|
||||
} else {
|
||||
// Increment the iterator if the current WorkNCCL object is not
|
||||
@ -3278,7 +3317,8 @@ void ProcessGroupNCCL::workEnqueue(
|
||||
pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
|
||||
pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
|
||||
pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
|
||||
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
|
||||
heartbeatMonitor_->setLastWorkListUpdateTime(
|
||||
std::chrono::steady_clock::now());
|
||||
}
|
||||
}
|
||||
|
||||
@ -3288,6 +3328,10 @@ ProcessGroupNCCL::Options::Options(bool is_high_priority_stream)
|
||||
|
||||
static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
|
||||
|
||||
uint64_t ProcessGroupNCCL::getWatchdogHeartbt() const {
|
||||
return heartbeat_.load();
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::startCoalescing() {
|
||||
// Other collective ops bump seq_ before creating a work. Thus, if coalesced
|
||||
// ops bump seq_ only after initing a work they will collide with (reuse) the
|
||||
|
@ -596,6 +596,86 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
std::string traceKeyEnd_;
|
||||
};
|
||||
|
||||
// Class that runs as a separate thread aside from watchdog
|
||||
// thread because we need to check the heartbeat from watchdog thread
|
||||
// so that when we get stuck in some NCCL/CUDA calls,
|
||||
// we can dump the debugging information and abort the process.
|
||||
class HeartbeatMonitor {
|
||||
public:
|
||||
HeartbeatMonitor(ProcessGroupNCCL* pg);
|
||||
virtual ~HeartbeatMonitor() = default;
|
||||
|
||||
// Start the heartbeat monitor thread.
|
||||
void start();
|
||||
|
||||
// Join the heartbeat monitor thread.
|
||||
void join();
|
||||
|
||||
// Run the actual loop to check watchdog heartbeat.
|
||||
virtual void runLoop();
|
||||
|
||||
// Set the terminal flag and notify the heartbeat monitor thread to stop.
|
||||
void stop();
|
||||
|
||||
// Set the last update time of watchdog thread.
|
||||
void setLastWorkListUpdateTime(
|
||||
std::chrono::time_point<std::chrono::steady_clock> time);
|
||||
|
||||
// Util function to get the timeout error message
|
||||
std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
|
||||
|
||||
// Util function to get the timeout exit message
|
||||
std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason);
|
||||
|
||||
protected:
|
||||
// We need to keep a reference to the PG instance so that we can access
|
||||
// the member functions of the PG instance. We store a raw pointer on
|
||||
// purpose because the heartbeat monitor thread now still lives within the
|
||||
// lifetime of the PG instance.
|
||||
ProcessGroupNCCL* pg_;
|
||||
|
||||
private:
|
||||
// Whether or not to print C++ stack traces to logs on unclean shutdown.
|
||||
bool logCppStackOnUncleanShutdown_;
|
||||
|
||||
// The time interval used for deciding whether there is no watchdog
|
||||
// heartbeat.
|
||||
int heartbeatTimeoutInSec_;
|
||||
|
||||
// timeout for the dump to finish.
|
||||
int waitTimeoutDumpInMilSec_;
|
||||
|
||||
// Interval of check coordinated signals in ProcessGroupNCCL from other
|
||||
// ranks e.g., trigger the dump of the debugging info for timeout when
|
||||
// notified.
|
||||
int coordCheckIntervalMilSec_;
|
||||
|
||||
// We gate the heartbeat monitor thread so that we can roll it out
|
||||
// gradually.
|
||||
bool watchdogHeartbeatMonitorEnabled_;
|
||||
|
||||
// Monitor thread which checks the heartbeat of Watchdog thread.
|
||||
// If the monitor thread finds there is no heartbeat, it will dump debug
|
||||
// info and then kill the watchdog thread to avoid hang.
|
||||
std::thread ncclHeartbeatMonitorThread_;
|
||||
|
||||
// Whether or not we should terminate the heartbeat monitoring threads.
|
||||
std::atomic<bool> terminateHeartbeatMonitorThread_{false};
|
||||
|
||||
// Condition Variable for monitor thread to wake up early
|
||||
std::condition_variable monitorWakeUpCV_;
|
||||
|
||||
// Whether or not to dump debug info on exception including both watchdog
|
||||
// timeout and nccl errors.
|
||||
bool dumpOnTimeoutOrEx_;
|
||||
|
||||
// Mutex to Guard monitorWakeUpCV_
|
||||
std::mutex monitorMutex_;
|
||||
|
||||
// The last update time of WorkList inside watchdog thread.
|
||||
std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
|
||||
};
|
||||
|
||||
// If you wish to create multiple process groups, each with a potentially
|
||||
// different rank and size, you can do so by passing a new store instance
|
||||
// to each one. If you have only a single store object, you can
|
||||
@ -862,6 +942,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
void setEnableNanCheck(bool enableNanCheck);
|
||||
|
||||
protected:
|
||||
uint64_t getWatchdogHeartbt() const;
|
||||
|
||||
// Instance of the heartbeat monitor thread.
|
||||
std::unique_ptr<HeartbeatMonitor> heartbeatMonitor_;
|
||||
|
||||
// Helper that broadcasts nccl unique ID to all ranks through the store
|
||||
void broadcastUniqueNCCLID(
|
||||
ncclUniqueId* ncclID,
|
||||
@ -1041,6 +1126,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// return the rank_ of the the very first PG created, aka, default global PG.
|
||||
const int& globalRank() const;
|
||||
|
||||
const c10::intrusive_ptr<Store>& globalStore() const;
|
||||
|
||||
// Returns the global ranks of a PG.
|
||||
const std::vector<uint64_t>& groupRanks() const;
|
||||
|
||||
@ -1066,12 +1153,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
const std::string& signal);
|
||||
|
||||
protected:
|
||||
// Function that runs as part of a separate thread aside from watchdog
|
||||
// thread because we need to check the heartbeat from watchdog thread
|
||||
// so that when we get stuck in some NCCL/CUDA calls,
|
||||
// we can dump the debugging information and abort the process.
|
||||
virtual void heartbeatMonitor();
|
||||
|
||||
// Function that directly trigger std::abort so that the whole process
|
||||
// gets terminated.
|
||||
virtual void terminateProcess(const std::string& errMsg);
|
||||
@ -1085,10 +1166,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
::c10d::C10dLoggingData& debugLog,
|
||||
bool throwException = false);
|
||||
|
||||
std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
|
||||
|
||||
std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason);
|
||||
|
||||
void checkAndSetRemoteError();
|
||||
|
||||
// A helper function to guess the device id of the current rank, based on
|
||||
@ -1171,30 +1248,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Heartbeat of watchdog thread.
|
||||
std::atomic_uint64_t heartbeat_{};
|
||||
|
||||
// The time interval used for deciding whether there is no watchdog heartbeat.
|
||||
int heartbeatTimeoutInSec_;
|
||||
|
||||
// timeout for the dump to finish.
|
||||
int waitTimeoutDumpInMilSec_;
|
||||
|
||||
// Interval of check coordinated signals in ProcessGroupNCCL from other ranks
|
||||
// e.g., trigger the dump of the debugging info for timeout when notified.
|
||||
int coordCheckIntervalMilSec_;
|
||||
|
||||
// Size of ring buffer where we store NCCL Traces for debugging.
|
||||
int traceBufferSize_;
|
||||
|
||||
// We gate the heartbeat monitor thread so that we can roll it out gradually.
|
||||
std::atomic<bool> monitorThreadEnabled_{};
|
||||
|
||||
// We gate the cudaEventCache so that we can roll it out gradually.
|
||||
std::atomic<bool> cudaEventCacheEnabled_{};
|
||||
|
||||
// Monitor thread which checks the heartbeat of Watchdog thread.
|
||||
// If the monitor thread finds there is no heartbeat, it will dump debug info
|
||||
// and then kill the watchdog thread to avoid hang.
|
||||
std::thread ncclHeartbeatMonitorThread_;
|
||||
|
||||
// Watchdog thread which looks for errors on the cached NCCL communicators.
|
||||
std::thread ncclCommWatchdogThread_;
|
||||
|
||||
@ -1203,9 +1265,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether or not we should terminate the watchdog and workCleanup threads.
|
||||
std::atomic<bool> terminateProcessGroup_;
|
||||
|
||||
// Whether or not we should terminate the heartbeat monitoring threads.
|
||||
std::atomic<bool> terminateHeartbeatMonitorThread_;
|
||||
|
||||
// Whether there are hooks pending to be fired
|
||||
std::atomic<bool> hasPendingHooks_{};
|
||||
|
||||
@ -1225,22 +1284,14 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Mutex to Guard workMetaList_
|
||||
std::mutex workMetaListMutex_;
|
||||
|
||||
// Mutex to Guard monitorWakeUpCV_
|
||||
std::mutex monitorMutex_;
|
||||
|
||||
bool writeDebugInfo_ = false;
|
||||
|
||||
// Condition Variable for watchdog thread sleep
|
||||
std::condition_variable workMetaListCV_;
|
||||
|
||||
// Condition Variable for monitor thread to wake up early
|
||||
std::condition_variable monitorWakeUpCV_;
|
||||
|
||||
// Vector to Store WorkNCCL pointers
|
||||
// Vector to store WorkNCCL pointers
|
||||
std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
|
||||
|
||||
std::chrono::time_point<std::chrono::steady_clock> lastWorkListUpdateTime_;
|
||||
|
||||
// Mutex to Guard workMetaList_
|
||||
std::mutex completedWorkListMutex_;
|
||||
|
||||
@ -1302,10 +1353,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
bool desyncDebug_;
|
||||
DesyncDebugger desyncDebugger_;
|
||||
|
||||
// Whether or not to dump debug info on exception including both watchdog
|
||||
// timeout and nccl errors.
|
||||
bool dumpOnTimeoutOrEx_;
|
||||
|
||||
// Whether or not to propagate detected errors to all ranks in the same PG
|
||||
// through TCPStore.
|
||||
bool propagatePgError_;
|
||||
@ -1316,9 +1363,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether or not to enable nan check for input tensors to collectives.
|
||||
bool enableNanCheck_;
|
||||
|
||||
// Whether or not to print C++ stack traces to logs on unclean shutdown.
|
||||
bool logCppStackOnUncleanShutdown_;
|
||||
|
||||
// Whether or not to create start CUDAEvent and enable timing for start
|
||||
// and end events. Note that enableTiming_ is always true if desyncDebug_
|
||||
// is set to true.
|
||||
|
Reference in New Issue
Block a user