Work: block_current_stream API (#156883)

This implements a new `block_current_stream` API in Work that gives CPU-based backends such as Gloo the same stream-blocking semantics that `wait` already provides for ProcessGroupNCCL.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.
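As a rough illustration of that overlap pattern, here is a minimal sketch (not FSDP internals): it assumes a process group is already initialized and the tensor lives on the current CUDA device; the helper name is made up, while `Work.block_current_stream()` is the API added here.

```
import torch
import torch.distributed as dist

def overlapped_all_reduce(t: torch.Tensor) -> None:
    # Kick off the collective asynchronously; for Gloo this runs on a CPU thread.
    work = dist.all_reduce(t, async_op=True)
    # Unlike work.wait(), which blocks the calling CPU thread for CPU backends,
    # this only fences the current CUDA stream on completion of the collective.
    work.block_current_stream()
    # Kernels enqueued on the current stream from here on will not execute
    # until `t` holds the reduced result, while the host keeps issuing work.
    t.mul_(2)
```

Only the GPU stream is fenced, so the Python thread stays free to keep launching work, which is what lets FSDP overlap the Gloo collective with compute.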

There was a previous attempt to make FSDPv2 use `Work.wait()` (https://github.com/pytorch/pytorch/pull/148780), but given the extensive stream semantics FSDP relies on, it didn't play nicely.

This uses a "Baton" CUDA kernel that spin-waits on a pinned CPU tensor until it is set.
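For intuition, here is a minimal sketch of such a baton kernel — hypothetical names and simplifications, not the actual kernel added in this PR — assuming the flag lives in host-pinned memory that is accessible from the device (e.g. allocated with `cudaHostAlloc` and mapped):

```
#include <cuda_runtime.h>
#include <cstdint>

// One spinning thread is enough: everything behind this kernel on the
// same stream is blocked until the flag flips.
__global__ void baton_wait(volatile int32_t* flag) {
  while (*flag == 0) {
#if __CUDA_ARCH__ >= 700
    __nanosleep(1000);  // back off so the spin does not saturate the bus
#endif
  }
}

// Enqueue the spin-wait on `stream`. Later, the CPU thread driving the
// Gloo collective completes and sets *pinned_flag = 1, releasing the stream.
void block_stream_until_set(int32_t* pinned_flag, cudaStream_t stream) {
  baton_wait<<<1, 1, 0, stream>>>(pinned_flag);
}
```

Because the spin happens on the GPU, only work enqueued behind the baton on that stream is held back; the host thread that issued the collective is never blocked.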

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k block_current_stream
pytest test/distributed/test_c10d_nccl.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
Commit: 1b3d69b59f (parent: 92f41ccc26)
Author: Tristan Rice
Date: 2025-07-08 23:55:41 +00:00
Committed by: PyTorch MergeBot
13 changed files with 254 additions and 1 deletion


```
@@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         )
         dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx)))
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_block_current_stream(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank}")
+        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
+
+        t = torch.rand(10, device=device)
+        work = pg.allreduce(t)
+        work.block_current_stream()
+        torch.cuda.current_stream().synchronize()
+
+        work.wait()
+        torch.cuda.synchronize()
+
 
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
```