Revert "[NCCL][CUDA][CUDA Graphs] Flush enqueued work before starting a graph capture (#104487)"

This reverts commit db63bf3d7e5eef320dde9c2d4b7976eb5fcddbd6.

Reverted https://github.com/pytorch/pytorch/pull/104487 on behalf of https://github.com/huydhn due to "Sorry for reverting your change, it is failing an internal build" ([comment](https://github.com/pytorch/pytorch/pull/104487#issuecomment-1707055346))
Committed by PyTorch MergeBot on 2023-09-05 17:57:16 +00:00
parent 29f1097891, commit 5b31a41841
3 changed files with 0 additions and 54 deletions

aten/src/ATen/cuda/CUDAGraph.cpp

@@ -4,7 +4,6 @@
 #include <ATen/Functions.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAFunctions.h>
-#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

 namespace at::cuda {
@@ -116,12 +115,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
   // due to the capture status being updated _after_ a capture had already started.
   c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(capture_dev_, capture_stream_, mempool_id_);

-#ifdef USE_C10D_NCCL
-  // If the watchdog has remaining work enqueued, an event query on the remaining work will crash
-  // the graph capture, so we wait for all pending work to be completed.
-  c10d::ProcessGroupNCCL::waitForAllPendingWorks();
-#endif
-
   // cudaStreamCaptureModeGlobal is the most conservative option to
   // prevent potentially unsafe CUDA API calls during capture. See
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
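For context on why the reverted hunk above drained the watchdog first: the removed comment notes that an event query against in-flight work crashes an ongoing capture. Below is a minimal, hypothetical standalone sketch of that failure mode using only host-side CUDA runtime calls; it is not part of this diff, and the exact error code can vary by CUDA version.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaEvent_t event;
  cudaStreamCreate(&stream);
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, stream);

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // Synchronous queries are disallowed while a global-mode capture is
  // active; this is exactly what the NCCL watchdog does when it polls
  // pending work, and it invalidates the capture.
  cudaError_t err = cudaEventQuery(event);
  std::printf("cudaEventQuery during capture: %s\n", cudaGetErrorString(err));

  cudaGraph_t graph;
  cudaStreamEndCapture(stream, &graph);  // capture is likely invalidated
  return 0;
}
```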

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

@@ -286,9 +286,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
-std::mutex ProcessGroupNCCL::allProcessGroupsMutex_;
-std::unordered_set<c10d::ProcessGroupNCCL*>
-    ProcessGroupNCCL::all_nccl_process_groups;

 std::ostream& operator<<(
     std::ostream& output,
@@ -731,10 +728,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     }
   }
 #endif
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.insert(this);
-  }
 }

 void ProcessGroupNCCL::runHealthCheck() {
@@ -904,11 +897,6 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
   // Abort all NCCL Communicators on Process Group Destruction
   std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
   abort(abortReason);
-
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.erase(this);
-  }
 }

 void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -1101,15 +1089,6 @@ void ProcessGroupNCCL::runHookLoop() {
   }
 }

-void ProcessGroupNCCL::waitForAllPendingWorks() {
-  std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-  for (auto it = ProcessGroupNCCL::all_nccl_process_groups.begin();
-       it != ProcessGroupNCCL::all_nccl_process_groups.end();
-       it++) {
-    (*it)->waitForPendingWorks();
-  }
-}
-
 std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
     const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
   return checkForNCCLErrorsInternal(ncclComms);
@@ -1670,13 +1649,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if ((coalescing_state_ & CoalColl) &&
       capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
@@ -1851,13 +1823,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->numelIn_ = inputs[0].numel();
   work->numelOut_ = outputs[0].numel();
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
   }
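The two removed assertions above share one pattern: while a capture is underway, the queue the watchdog polls must already be empty, and no new work may be handed to it. A simplified sketch of that guard follows; the types and names here are standalone stand-ins, not PyTorch's real API.

```cpp
#include <cassert>
#include <list>
#include <mutex>

enum class CaptureStatus { None, Active };

struct Work {};

std::mutex workMetaListMutex;   // mirrors workMetaListMutex_
std::list<Work> workMetaList;   // mirrors workMetaList_

void maybeEnqueue(const Work& w, CaptureStatus status) {
  if (status != CaptureStatus::None) {
    // A non-empty queue here means the watchdog may poll (cudaEventQuery)
    // an entry mid-capture and invalidate the capture, hence the hard
    // assert (TORCH_INTERNAL_ASSERT in the reverted hunks).
    std::lock_guard<std::mutex> lock(workMetaListMutex);
    assert(workMetaList.empty() &&
           "enqueued work during a CUDA Graph capture");
    return;  // skip enqueueing while capturing
  }
  std::lock_guard<std::mutex> lock(workMetaListMutex);
  workMetaList.push_back(w);
}
```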

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

@@ -8,7 +8,6 @@
 #include <mutex>
 #include <thread>
 #include <unordered_map>
-#include <unordered_set>

 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -346,11 +345,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   ~ProcessGroupNCCL() override;

-  // Check that all work is done (no enqueued work).
-  // We use this to avoid unwittingly having watchdogs query work during
-  // CUDA graph captures.
-  static void waitForAllPendingWorks();
-
   c10::intrusive_ptr<Options> getOptions() {
     return options_;
   }
@@ -695,9 +689,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Mutex to Guard workMetaList_
   std::mutex workMetaListMutex_;

-  // Mutex to Guard all_nccl_process_groups
-  static std::mutex allProcessGroupsMutex_;
-
   // Condition Variable for watchdog thread sleep
   std::condition_variable workMetaListCV_;
@@ -712,9 +703,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;

-  // All process groups for checking Watchdog status
-  static std::unordered_set<c10d::ProcessGroupNCCL*> all_nccl_process_groups;
-
   // Add Work Pointer to workVector
   void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
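Taken together, the reverted pieces formed a small registry: every ProcessGroupNCCL inserted itself into a global set on construction, erased itself on destruction, and capture_begin drained every registered group before capturing. The following self-contained sketch shows that shape with illustrative names and simplified synchronization; it is not the reverted implementation itself.

```cpp
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <unordered_set>

class FakeProcessGroup {
 public:
  void enqueueWork(int id) {
    std::lock_guard<std::mutex> lk(mutex_);
    pending_.push_back(id);
  }

  // Stands in for the watchdog thread retiring completed work.
  void completeAll() {
    {
      std::lock_guard<std::mutex> lk(mutex_);
      pending_.clear();
    }
    cv_.notify_all();
  }

  // Block until the watchdog has nothing left to poll.
  void waitForPendingWorks() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [this] { return pending_.empty(); });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<int> pending_;
};

// Global registry, mirroring allProcessGroupsMutex_ /
// all_nccl_process_groups in the reverted hunks.
std::mutex gRegistryMutex;
std::unordered_set<FakeProcessGroup*> gRegistry;

// What CUDAGraph::capture_begin called before starting a capture.
void waitForAllPendingWorks() {
  std::lock_guard<std::mutex> lk(gRegistryMutex);
  for (auto* pg : gRegistry) {
    pg->waitForPendingWorks();
  }
}

int main() {
  FakeProcessGroup pg;
  {
    std::lock_guard<std::mutex> lk(gRegistryMutex);
    gRegistry.insert(&pg);   // mirrors the constructor hunk
  }
  pg.enqueueWork(42);
  pg.completeAll();          // pretend the watchdog finished everything
  waitForAllPendingWorks();  // now safe to begin a graph capture
  std::cout << "pending work drained; capture may begin\n";
  {
    std::lock_guard<std::mutex> lk(gRegistryMutex);
    gRegistry.erase(&pg);    // mirrors the destructor hunk
  }
}
```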