diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 19aa6f09d8d2..b1ede292c0a3 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -377,7 +377,7 @@ std::string NCCLComm::repr() const { return c10::str((void*)ncclComm_); } -#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) +#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map NCCLComm::ncclCommDump() { std::unordered_map dump; if (isAborted()) { @@ -521,6 +521,17 @@ std::string getNcclErrorDetailStr( return interpret + err; } +// Dump proxyTrace log to stdout +void printNcclCommProxyTrace( + std::string dumpReason, + const std::unordered_map& dumpMap) { + LOG(INFO) << "Dumping nccl comm trace, reason: " << dumpReason; + for (auto& [key, value] : dumpMap) { + LOG(INFO) << "key: " << key << ", value: " << value; + } + LOG(INFO) << "----------------------"; +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index ffb3a1f3dca0..3d7d5584e38f 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -242,7 +242,7 @@ class NCCLComm { std::vector& ranks_ull); #endif -#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) +#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif @@ -356,6 +356,9 @@ struct ncclRedOpRAII { bool premul_sum_ = false; }; +void printNcclCommProxyTrace( + std::string dumpReason, + const std::unordered_map& dumpMap); } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index df8de61474a6..b05a2b9b3680 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -347,7 +347,7 @@ static void cacheAllocatorDeregisterHook( static std:: unordered_map> getNCCLCommDumpMap() { -#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) +#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map< std::string /* ncclUniqueID */, std::unordered_map /* dump from this comm */> @@ -380,6 +380,11 @@ std::string dump_nccl_trace( bool includeStackTraces, bool onlyActive) { auto ncclDumpMap = getNCCLCommDumpMap(); +#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP) + for (const auto& [ncclUniqueIDStr, dump] : ncclDumpMap) { + printNcclCommProxyTrace("Received dump signal " + ncclUniqueIDStr, dump); + } +#endif return FlightRecorder::get()->dump( ncclDumpMap, includeCollectives, includeStackTraces, onlyActive); } @@ -789,6 +794,12 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { } void ProcessGroupNCCL::WorkNCCL::abort() { + // dump before aborting for rcclexp +#if defined(USE_ROCM) && defined(NCCL_COMM_DUMP) + auto dumpMap = ncclComm_->ncclCommDump(); + printNcclCommProxyTrace("WorkNCCL::abort", dumpMap); +#endif + // Abort all communicators of this work ncclComm_->abort();