move thread-local capture mode guard to include work.isStarted (#160398)

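As the title says, this moves the device guard and the thread-local stream-capture-mode guard earlier in the watchdog loop so that they also cover the `work.isStarted()` check. This should fix CUDA stream capture errors that occur when the NCCL watchdog thread races with the start of a capture.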

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160398
Approved by: https://github.com/aorenste
This commit is contained in:
Natalia Gimelshein
2025-08-12 19:25:04 +00:00
committed by PyTorch MergeBot
parent 9903ca4f70
commit 2d0cdee394

View File

@ -2284,6 +2284,10 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
// Work status logging for desync debug // Work status logging for desync debug
desyncDebugger_.logWorkStart(work); desyncDebugger_.logWorkStart(work);
// allow watchdog to do an event query on a side thread
at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
// a work could be started but not completed, so we should not update // a work could be started but not completed, so we should not update
// lastStartedSeq and lastStartedOpName if the work state is checked // lastStartedSeq and lastStartedOpName if the work state is checked
// multiple times after the start // multiple times after the start
@ -2295,10 +2299,6 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastStartedNumelOut = work.numelOut_; pg_->pgStatus_->lastStartedNumelOut = work.numelOut_;
} }
// allow watchdog to do an event query on a side thread
at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index());
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal};
// Clean up completed work // Clean up completed work
if (work.isCompleted()) { if (work.isCompleted()) {
// In case user didn't call `work.wait()` with async collectives, // In case user didn't call `work.wait()` with async collectives,