Work: block_current_stream API (#156883)

This implements a new `wait_stream` API in Work that matches how `wait` works for ProcessGroupNCCL for CPU based backends such as Gloo.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.

There was a previous attempt to make FSDPv2 use Work.wait but given the extensive stream semantics used it doesn't play nicely. https://github.com/pytorch/pytorch/pull/148780

This uses a "Baton" CUDA kernel which spinlocks on a pinned CPU tensor waiting for it to be set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
This commit is contained in:
Tristan Rice
2025-07-08 23:55:41 +00:00
committed by PyTorch MergeBot
parent 92f41ccc26
commit 1b3d69b59f
13 changed files with 254 additions and 1 deletions

View File

@ -518,6 +518,7 @@ libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
"torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
"torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
"torch/csrc/distributed/c10d/cuda/StreamBlock.cpp",
"torch/csrc/distributed/c10d/debug.cpp",
"torch/csrc/distributed/c10d/default_comm_hooks.cpp",
"torch/csrc/distributed/c10d/logger.cpp",
@ -734,6 +735,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/UCCUtils.cpp",
"torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
"torch/csrc/distributed/c10d/cuda/utils.cpp",
"torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",