From 2d0cdee394bccadcd0abe19dd4623ed978a331ad Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 12 Aug 2025 19:25:04 +0000 Subject: [PATCH] move thread-local capture mode guard to include work.isStarted (#160398) Per title, should fix capture errors that happen because nccl watchdog races with capture start. Pull Request resolved: https://github.com/pytorch/pytorch/pull/160398 Approved by: https://github.com/aorenste --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3cb6aee8b9df..3e9802d855e7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2284,6 +2284,10 @@ void ProcessGroupNCCL::Watchdog::runLoop() { // Work status logging for desync debug desyncDebugger_.logWorkStart(work); + // allow watchdog to do an event query on a side thread + at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); + at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; + // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start @@ -2295,10 +2299,6 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; } - // allow watchdog to do an event query on a side thread - at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); - at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; - // Clean up completed work if (work.isCompleted()) { // In case user didn't call `work.wait()` with async collectives,