Mirror of https://github.com/pytorch/pytorch.git · synced 2025-10-21 13:44:15 +08:00
Revert "[NCCL][CUDA][CUDA Graphs] Flush enqueued work before starting a graph capture (#104487)"
This reverts commit db63bf3d7e5eef320dde9c2d4b7976eb5fcddbd6. Reverted https://github.com/pytorch/pytorch/pull/104487 on behalf of https://github.com/huydhn due to Sorry for reverting you change, it is failing internal build ([comment](https://github.com/pytorch/pytorch/pull/104487#issuecomment-1707055346))
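Context for the revert: the comment inside the removed capture_begin block below states the motivation — if the NCCL watchdog still has work enqueued, its event query will crash an active graph capture. A minimal standalone repro sketch of that failure mode, assuming only the CUDA runtime API; the file name and the exact error code are this reading of the stream-capture rules, not taken from the PR:

```cpp
// event_query_during_capture.cu (hypothetical name; build with nvcc)
// Querying an event recorded before capture, while a global-mode capture is
// active, is an "unsafe action" expected to error and invalidate the capture.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaEvent_t event;
  cudaStreamCreate(&stream);  // capture requires a non-default stream
  cudaEventCreate(&event);

  // Stand-in for an event belonging to still-enqueued watchdog work.
  cudaEventRecord(event, stream);

  // capture_begin() uses cudaStreamCaptureModeGlobal, the most conservative mode.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // A watchdog-style poll of the old event; expected to return
  // cudaErrorStreamCaptureUnsupported and break the capture.
  cudaError_t err = cudaEventQuery(event);
  printf("cudaEventQuery during capture: %s\n", cudaGetErrorString(err));

  cudaGraph_t graph;
  err = cudaStreamEndCapture(stream, &graph);
  printf("cudaStreamEndCapture: %s\n", cudaGetErrorString(err));

  cudaEventDestroy(event);
  cudaStreamDestroy(stream);
  return 0;
}
```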
aten/src/ATen/cuda/CUDAGraph.cpp
@@ -4,7 +4,6 @@
 #include <ATen/Functions.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAFunctions.h>
-#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 
 namespace at::cuda {
 
@@ -116,12 +115,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
   // due to the capture status being updated _after_ a capture had already started.
   c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(capture_dev_, capture_stream_, mempool_id_);
 
-#ifdef USE_C10D_NCCL
-  // If the watchdog has remaining work enqueued, an event query on the remaining work will crash
-  // the graph capture, so we wait for all pending work to be completed.
-  c10d::ProcessGroupNCCL::waitForAllPendingWorks();
-#endif
-
   // cudaStreamCaptureModeGlobal is the most conservative option to
   // prevent potentially unsafe CUDA API calls during capture. See
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
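The block removed above implemented a drain-before-capture handshake: capture_begin blocks until the watchdog has retired every enqueued work item, so no event query can race with the capture. A minimal sketch of that handshake using plain C++ threads; WorkQueue and its members are hypothetical stand-ins, not the PyTorch types:

```cpp
// drain_then_capture.cpp -- sketch of the drain-before-capture pattern.
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <mutex>
#include <thread>

struct WorkQueue {
  std::mutex mu;
  std::condition_variable cv;
  std::deque<int> pending;  // stand-in for enqueued WorkNCCL entries

  void push(int w) {
    { std::lock_guard<std::mutex> lk(mu); pending.push_back(w); }
    cv.notify_all();
  }
  // Watchdog retires one item per step; here it just pops.
  void watchdog_step() {
    std::lock_guard<std::mutex> lk(mu);
    if (!pending.empty()) pending.pop_front();
    if (pending.empty()) cv.notify_all();
  }
  // The reverted change called the equivalent of this before capture:
  // block until every enqueued work item has been retired.
  void wait_until_drained() {
    std::unique_lock<std::mutex> lk(mu);
    cv.wait(lk, [&] { return pending.empty(); });
  }
};

int main() {
  WorkQueue q;
  for (int i = 0; i < 3; ++i) q.push(i);

  std::thread watchdog([&] {
    for (int i = 0; i < 3; ++i) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      q.watchdog_step();
    }
  });

  q.wait_until_drained();  // flush enqueued work...
  std::puts("queue drained; safe to begin graph capture");
  // ...only now would capture_begin() start the CUDA stream capture.
  watchdog.join();
  return 0;
}
```

PyTorch's per-group waitForPendingWorks enforces the same invariant; the condition variable here is just one way to express the wait.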
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -286,9 +286,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
-std::mutex ProcessGroupNCCL::allProcessGroupsMutex_;
-std::unordered_set<c10d::ProcessGroupNCCL*>
-    ProcessGroupNCCL::all_nccl_process_groups;
 
 std::ostream& operator<<(
     std::ostream& output,
@@ -731,10 +728,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     }
   }
 #endif
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.insert(this);
-  }
 }
 
 void ProcessGroupNCCL::runHealthCheck() {
@@ -904,11 +897,6 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
   // Abort all NCCL Communicators on Process Group Destruction
   std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
   abort(abortReason);
-
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.erase(this);
-  }
 }
 
 void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -1101,15 +1089,6 @@ void ProcessGroupNCCL::runHookLoop() {
   }
 }
 
-void ProcessGroupNCCL::waitForAllPendingWorks() {
-  std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-  for (auto it = ProcessGroupNCCL::all_nccl_process_groups.begin();
-       it != ProcessGroupNCCL::all_nccl_process_groups.end();
-       it++) {
-    (*it)->waitForPendingWorks();
-  }
-}
-
 std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
     const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
   return checkForNCCLErrorsInternal(ncclComms);
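Because waitForAllPendingWorks is static, it needs a way to reach every live process group: the constructor and destructor hunks above maintain a static, mutex-guarded set of instances that this method iterates. A self-contained sketch of that instance-registry pattern; Group, registry_, and waitForAll are illustrative names, not the PyTorch API:

```cpp
// instance_registry.cpp -- sketch of a static registry of live instances.
#include <cstdio>
#include <mutex>
#include <unordered_set>

class Group {
 public:
  Group() {
    std::lock_guard<std::mutex> lk(registryMutex_);
    registry_.insert(this);  // mirrors the insert in the constructor hunk
  }
  ~Group() {
    std::lock_guard<std::mutex> lk(registryMutex_);
    registry_.erase(this);   // mirrors the erase in the destructor hunk
  }
  void waitForPendingWorks() { std::puts("draining one group"); }

  // Mirrors ProcessGroupNCCL::waitForAllPendingWorks(): visit every live
  // instance under the registry lock.
  static void waitForAll() {
    std::lock_guard<std::mutex> lk(registryMutex_);
    for (Group* g : registry_) g->waitForPendingWorks();
  }

 private:
  static std::mutex registryMutex_;
  static std::unordered_set<Group*> registry_;
};

std::mutex Group::registryMutex_;
std::unordered_set<Group*> Group::registry_;

int main() {
  Group a, b;
  Group::waitForAll();  // drains both a and b
  return 0;
}
```

Note the registry lock is held while each group drains, which serializes the waits; that is acceptable for a call made once, right before a capture begins.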
@@ -1670,13 +1649,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
 
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if ((coalescing_state_ & CoalColl) &&
       capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
@@ -1851,13 +1823,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->numelIn_ = inputs[0].numel();
   work->numelOut_ = outputs[0].numel();
 
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
   }
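The two assertions removed above enforce the drain invariant from the other side: once a capture is active, the watchdog queue must already be empty (capture_begin drained it), and newly finished work is deliberately not enqueued — note the capture_status checks guarding workEnqueue. A condensed sketch of the guard, with a bool standing in for c10::cuda::CaptureStatus and assert for TORCH_INTERNAL_ASSERT:

```cpp
// capture_guard.cpp -- sketch of the enqueue-or-assert guard.
#include <cassert>
#include <deque>
#include <mutex>

std::mutex workMutex;
std::deque<int> workList;  // stand-in for workMetaList_

void finishCollective(int work, bool capturing) {
  std::lock_guard<std::mutex> lk(workMutex);
  if (capturing) {
    // Invariant from the reverted change: capture_begin() drained the queue,
    // so nothing is left for the watchdog to poll mid-capture.
    assert(workList.empty() && "watchdog would poll work during capture");
  } else {
    // Only outside a capture is work handed to the watchdog.
    workList.push_back(work);
  }
}

int main() {
  finishCollective(1, /*capturing=*/false);  // enqueued for the watchdog
  {
    std::lock_guard<std::mutex> lk(workMutex);
    workList.clear();                        // watchdog retires it
  }
  finishCollective(2, /*capturing=*/true);   // invariant holds: queue is empty
  return 0;
}
```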
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -8,7 +8,6 @@
 #include <mutex>
 #include <thread>
 #include <unordered_map>
-#include <unordered_set>
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -346,11 +345,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   ~ProcessGroupNCCL() override;
 
-  // Check that all work is done (no enqueued work).
-  // We use this to avoid unwittingly having watchdogs query work during
-  // CUDA graph captures.
-  static void waitForAllPendingWorks();
-
   c10::intrusive_ptr<Options> getOptions() {
     return options_;
   }
@@ -695,9 +689,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Mutex to Guard workMetaList_
   std::mutex workMetaListMutex_;
 
-  // Mutex to Guard all_nccl_process_groups
-  static std::mutex allProcessGroupsMutex_;
-
   // Condition Variable for watchdog thread sleep
   std::condition_variable workMetaListCV_;
 
@@ -712,9 +703,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;
 
-  // All process groups for checking Watchdog status
-  static std::unordered_set<c10d::ProcessGroupNCCL*> all_nccl_process_groups;
-
   // Add Work Pointer to workVector
   void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
 