mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PTD] Dump rcclexp proxy trace in pytorch (#143678)
Summary: Dump the active proxyOp status per rank and per communicator when the WatchDog times out or aborts. Added a `#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP)` guard in the print function, so only rcclexp users will see this dump in the console. These are the PTD-side changes. Test Plan: Job with an A2A hang due to the receiver failing to post receive operations https://fburl.com/mlhub/95vg12r3 {F1971449692} Reviewed By: c-p-i-o Differential Revision: D67036093 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143678 Approved by: https://github.com/c-p-i-o
This commit is contained in:
committed by
PyTorch MergeBot
parent
aa7d01ea22
commit
a881954b0c
@ -377,7 +377,7 @@ std::string NCCLComm::repr() const {
|
||||
return c10::str((void*)ncclComm_);
|
||||
}
|
||||
|
||||
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
|
||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
|
||||
std::unordered_map<std::string, std::string> dump;
|
||||
if (isAborted()) {
|
||||
@ -521,6 +521,17 @@ std::string getNcclErrorDetailStr(
|
||||
return interpret + err;
|
||||
}
|
||||
|
||||
// Dump proxyTrace log to stdout
|
||||
void printNcclCommProxyTrace(
|
||||
std::string dumpReason,
|
||||
const std::unordered_map<std::string, std::string>& dumpMap) {
|
||||
LOG(INFO) << "Dumping nccl comm trace, reason: " << dumpReason;
|
||||
for (auto& [key, value] : dumpMap) {
|
||||
LOG(INFO) << "key: " << key << ", value: " << value;
|
||||
}
|
||||
LOG(INFO) << "----------------------";
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
||||
#endif // USE_C10D_NCCL
|
||||
|
@ -242,7 +242,7 @@ class NCCLComm {
|
||||
std::vector<uint64_t>& ranks_ull);
|
||||
#endif
|
||||
|
||||
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
|
||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<std::string, std::string> ncclCommDump();
|
||||
#endif
|
||||
|
||||
@ -356,6 +356,9 @@ struct ncclRedOpRAII {
|
||||
bool premul_sum_ = false;
|
||||
};
|
||||
|
||||
// Logs a communicator dump at INFO level: a header line containing
// `dumpReason`, one line per key/value pair in `dumpMap`, then a separator
// line.
void printNcclCommProxyTrace(
|
||||
std::string dumpReason,
|
||||
const std::unordered_map<std::string, std::string>& dumpMap);
|
||||
} // namespace c10d
|
||||
|
||||
#endif // USE_C10D_NCCL
|
||||
|
@ -347,7 +347,7 @@ static void cacheAllocatorDeregisterHook(
|
||||
static std::
|
||||
unordered_map<std::string, std::unordered_map<std::string, std::string>>
|
||||
getNCCLCommDumpMap() {
|
||||
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
|
||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<
|
||||
std::string /* ncclUniqueID */,
|
||||
std::unordered_map<std::string, std::string> /* dump from this comm */>
|
||||
@ -380,6 +380,11 @@ std::string dump_nccl_trace(
|
||||
bool includeStackTraces,
|
||||
bool onlyActive) {
|
||||
auto ncclDumpMap = getNCCLCommDumpMap();
|
||||
#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP)
|
||||
for (const auto& [ncclUniqueIDStr, dump] : ncclDumpMap) {
|
||||
printNcclCommProxyTrace("Received dump signal " + ncclUniqueIDStr, dump);
|
||||
}
|
||||
#endif
|
||||
return FlightRecorder::get()->dump(
|
||||
ncclDumpMap, includeCollectives, includeStackTraces, onlyActive);
|
||||
}
|
||||
@ -789,6 +794,12 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::WorkNCCL::abort() {
|
||||
// dump before aborting for rcclexp
|
||||
#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP)
|
||||
auto dumpMap = ncclComm_->ncclCommDump();
|
||||
printNcclCommProxyTrace("WorkNCCL::abort", dumpMap);
|
||||
#endif
|
||||
|
||||
// Abort all communicators of this work
|
||||
ncclComm_->abort();
|
||||
|
||||
|
Reference in New Issue
Block a user