Work: block_current_stream API (#156883)

This implements a new `block_current_stream` API on `Work` that gives CPU-based backends such as Gloo the same stream-blocking behavior that `wait` provides for ProcessGroupNCCL.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.

There was a previous attempt to make FSDPv2 use `Work.wait()`, but given the extensive stream semantics FSDP relies on, it doesn't play nicely: https://github.com/pytorch/pytorch/pull/148780

This uses a "Baton" CUDA kernel which spinlocks on a pinned CPU tensor waiting for it to be set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
Author: Tristan Rice
Date: 2025-07-08 23:55:41 +00:00
Committed by: PyTorch MergeBot
Parent: 92f41ccc26
Commit: 1b3d69b59f
13 changed files with 254 additions and 1 deletion


@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
However, if timeout is set, it will block the CPU thread until the NCCL work is completed
or timed out. If timeout, exception will be thrown.
)")
.def(
"block_current_stream",
&::c10d::Work::blockCurrentStream,
py::call_guard<py::gil_scoped_release>(),
R"(
Blocks the currently active GPU stream on the operation to
complete. For GPU based collectives this is equivalent to
synchronize. For CPU initiated collectives such as with Gloo this
will block the CUDA stream until the operation is complete.
This returns immediately in all cases.
To check whether an operation was successful you should check the
Work object result asynchronously.
)")
.def(
"get_future_result",
[](::c10d::Work& work)
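
Since the docstring above notes that `block_current_stream` returns immediately and that success should be checked asynchronously on the Work object, here is a hedged sketch of that pattern; the timing print and the optimizer-style update are illustrative, and only the bound methods visible in this diff plus standard `torch.distributed` calls are used:

```python
# Hedged sketch: block_current_stream() vs. Work.wait() for a Gloo collective.
# The optimizer-style update and timing print are illustrative only.
import time
import torch
import torch.distributed as dist


def overlap_step(grad: torch.Tensor, weight: torch.Tensor) -> None:
    work = dist.all_reduce(grad, async_op=True)

    # Work.wait() would block this CPU thread until the Gloo collective finishes;
    # block_current_stream() instead returns immediately and only gates the GPU.
    t0 = time.perf_counter()
    work.block_current_stream()
    print(f"block_current_stream returned in {time.perf_counter() - t0:.6f}s")

    # Ordered after the all_reduce on the current stream, even though the CPU
    # never waited for the collective.
    weight.add_(grad, alpha=-1e-2)

    # Success/failure is surfaced asynchronously on the Work object (e.g. via
    # its result future); wait() here only joins before reusing the buffers.
    work.wait()
```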