Revert "[NCCL][CUDA][CUDA Graphs] Flush enqueued work before starting a graph capture (#104487)"
This reverts commit db63bf3d7e5eef320dde9c2d4b7976eb5fcddbd6. Reverted https://github.com/pytorch/pytorch/pull/104487 on behalf of https://github.com/huydhn due to Sorry for reverting your change, it is failing internal build ([comment](https://github.com/pytorch/pytorch/pull/104487#issuecomment-1707055346))
@@ -4,7 +4,6 @@
 #include <ATen/Functions.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAFunctions.h>
-#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 
 namespace at::cuda {
 
@@ -116,12 +115,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt
   // due to the capture status being updated _after_ a capture had already started.
   c10::cuda::CUDACachingAllocator::beginAllocateStreamToPool(capture_dev_, capture_stream_, mempool_id_);
 
-#ifdef USE_C10D_NCCL
-  // If the watchdog has remaining work enqueued, an event query on the remaining work will crash
-  // the graph capture, so we wait for all pending work to be completed.
-  c10d::ProcessGroupNCCL::waitForAllPendingWorks();
-#endif
-
   // cudaStreamCaptureModeGlobal is the most conservative option to
   // prevent potentially unsafe CUDA API calls during capture. See
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85
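
For context, the hook removed above enforced a simple ordering around the start of a capture: drain any outstanding asynchronous work before entering a cudaStreamCaptureModeGlobal capture, so that no other thread (here, the NCCL watchdog) issues an event query while the capture is live. The following is a minimal, generic CUDA-runtime sketch of that ordering only; it is not the reverted PyTorch code path, and the buffer, sizes, and stream setup are assumptions for illustration.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float* buf = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&buf), 1024 * sizeof(float));

  // Launch some asynchronous work, then drain it completely. This mirrors the
  // intent of the removed waitForAllPendingWorks() call: nothing may still be
  // in flight (or still be polled by another thread) once the capture begins.
  cudaMemsetAsync(buf, 0, 1024 * sizeof(float), stream);
  cudaDeviceSynchronize();

  // Begin a capture in the most conservative mode. Unsafe CUDA API calls made
  // anywhere in the process while this capture is active can invalidate it.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  cudaMemsetAsync(buf, 0, 1024 * sizeof(float), stream);  // recorded into the graph, not executed
  cudaGraph_t graph = nullptr;
  cudaStreamEndCapture(stream, &graph);

  // Instantiating and replaying the graph is omitted; the point here is only
  // that all prior work was drained before the capture began.
  cudaGraphDestroy(graph);
  cudaFree(buf);
  cudaStreamDestroy(stream);
  printf("capture finished cleanly\n");
  return 0;
}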
@@ -286,9 +286,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
 const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 1000;
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
-std::mutex ProcessGroupNCCL::allProcessGroupsMutex_;
-std::unordered_set<c10d::ProcessGroupNCCL*>
-    ProcessGroupNCCL::all_nccl_process_groups;
 
 std::ostream& operator<<(
     std::ostream& output,
@@ -731,10 +728,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     }
   }
 #endif
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.insert(this);
-  }
 }
 
 void ProcessGroupNCCL::runHealthCheck() {
@@ -904,11 +897,6 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
   // Abort all NCCL Communicators on Process Group Destruction
   std::string abortReason = c10::str("Process Group destroyed on rank ", rank_);
   abort(abortReason);
-
-  {
-    std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-    all_nccl_process_groups.erase(this);
-  }
 }
 
 void ProcessGroupNCCL::ncclCommWatchdog() {
@@ -1101,15 +1089,6 @@ void ProcessGroupNCCL::runHookLoop() {
   }
 }
 
-void ProcessGroupNCCL::waitForAllPendingWorks() {
-  std::lock_guard<std::mutex> lk(allProcessGroupsMutex_);
-  for (auto it = ProcessGroupNCCL::all_nccl_process_groups.begin();
-       it != ProcessGroupNCCL::all_nccl_process_groups.end();
-       it++) {
-    (*it)->waitForPendingWorks();
-  }
-}
-
 std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
     const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
   return checkForNCCLErrorsInternal(ncclComms);
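
The hunk above removes the core of the reverted mechanism: a process-wide registry of live process groups, guarded by a static mutex, that the graph-capture code could iterate to drain every group's pending watchdog work. Below is a standalone sketch of that registry pattern under the same locking discipline; the Worker class and member names are placeholders, not the PyTorch types.

#include <mutex>
#include <unordered_set>

class Worker {
 public:
  Worker() {
    std::lock_guard<std::mutex> lk(instancesMutex_);
    instances_.insert(this);  // register on construction
  }
  ~Worker() {
    std::lock_guard<std::mutex> lk(instancesMutex_);
    instances_.erase(this);   // unregister on destruction
  }

  // Block until this instance has no outstanding work (a stub here).
  void waitForPendingWorks() {}

  // Class-level drain: visit every live instance under the registry lock,
  // mirroring the removed ProcessGroupNCCL::waitForAllPendingWorks().
  static void waitForAllPendingWorks() {
    std::lock_guard<std::mutex> lk(instancesMutex_);
    for (Worker* w : instances_) {
      w->waitForPendingWorks();
    }
  }

 private:
  static std::mutex instancesMutex_;
  static std::unordered_set<Worker*> instances_;
};

std::mutex Worker::instancesMutex_;
std::unordered_set<Worker*> Worker::instances_;

int main() {
  Worker a, b;                       // constructors register both instances
  Worker::waitForAllPendingWorks();  // drains every registered instance
}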
@@ -1670,13 +1649,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
   c10::cuda::CaptureStatus capture_status =
       c10::cuda::currentStreamCaptureStatusMayInitCtx();
 
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if ((coalescing_state_ & CoalColl) &&
       capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
@@ -1851,13 +1823,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   work->numelIn_ = inputs[0].numel();
   work->numelOut_ = outputs[0].numel();
 
-  if (capture_status != c10::cuda::CaptureStatus::None) {
-    std::lock_guard<std::mutex> lock(workMetaListMutex_);
-    TORCH_INTERNAL_ASSERT(
-        workMetaList_.empty(),
-        "In the middle of a CUDA Graph capture but the enqueued work is not empty. The watchdog will crash the capture when it polls the work.");
-  }
-
   if (!coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None) {
     workEnqueue(work);
   }
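
The two hunks above (endCoalescing and collective) remove the same guard: while a CUDA graph capture is in progress, the reverted code asserted that the watchdog queue was already empty and skipped enqueueing new work for watchdog tracking. A minimal sketch of that guard is shown below, with placeholder types standing in for the capture status and the watchdog work queue; it is illustrative only.

#include <cassert>
#include <deque>
#include <mutex>

// Placeholder stand-ins for c10::cuda::CaptureStatus and WorkNCCL.
enum class CaptureStatus { None, Active };
struct Work {};

std::mutex workQueueMutex;
std::deque<Work> workQueue;  // polled periodically by a watchdog thread

// Track work for the watchdog only when no capture is active; during a
// capture the queue must already be empty, so the watchdog has nothing to
// poll and cannot issue an event query that would break the capture.
void enqueueForWatchdog(const Work& w, CaptureStatus status) {
  std::lock_guard<std::mutex> lock(workQueueMutex);
  if (status != CaptureStatus::None) {
    assert(workQueue.empty() &&
           "enqueued work present during a CUDA graph capture");
    return;  // skip watchdog tracking while capturing
  }
  workQueue.push_back(w);
}

int main() {
  enqueueForWatchdog(Work{}, CaptureStatus::None);    // tracked normally
  enqueueForWatchdog(Work{}, CaptureStatus::Active);  // skipped during capture
}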
@@ -8,7 +8,6 @@
 #include <mutex>
 #include <thread>
 #include <unordered_map>
-#include <unordered_set>
 
 #include <torch/csrc/distributed/c10d/Backend.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -346,11 +345,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   ~ProcessGroupNCCL() override;
 
-  // Check that all work is done (no enqueued work).
-  // We use this to avoid uwittingly having watchdogs query work during
-  // CUDA graph captures.
-  static void waitForAllPendingWorks();
-
   c10::intrusive_ptr<Options> getOptions() {
     return options_;
   }
@@ -695,9 +689,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // Mutex to Guard workMetaList_
   std::mutex workMetaListMutex_;
 
-  // Mutex to Guard all_nccl_process_groups
-  static std::mutex allProcessGroupsMutex_;
-
   // Condition Variable for watchdog thread sleep
   std::condition_variable workMetaListCV_;
 
@@ -712,9 +703,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList_;
 
-  // All process groups for checking Watchdog status
-  static std::unordered_set<c10d::ProcessGroupNCCL*> all_nccl_process_groups;
-
   // Add Work Pointer to workVector
   void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
 