Work: block_current_stream API (#156883)
This implements a new `block_current_stream` API on `Work` that gives CPU based backends such as Gloo the same stream-blocking behavior that `wait` already provides for ProcessGroupNCCL. The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP. There was a previous attempt to make FSDPv2 use `Work.wait`, but given the extensive stream semantics involved it doesn't play nicely: https://github.com/pytorch/pytorch/pull/148780

This uses a "Baton" CUDA kernel which spins on a pinned CPU tensor, waiting for it to be set.

Test plan:
```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
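As a rough illustration of the intended usage (not code from this PR; the function and tensor names are made up, and it assumes an already-initialized Gloo process group plus a CUDA device), the new call lets later work on the current CUDA stream be ordered after a CPU-driven collective without blocking the Python thread:

```python
# Illustrative sketch only: overlap a Gloo collective with GPU work by making
# the current CUDA stream wait on the CPU-driven operation.
import torch
import torch.distributed as dist

def overlapped_all_reduce(grad: torch.Tensor) -> None:
    # Assumes dist.init_process_group("gloo", ...) has already been called.
    work = dist.all_reduce(grad, async_op=True)  # runs on a CPU thread for Gloo

    # New API from this PR: enqueues a wait (the "Baton" kernel for Gloo) on the
    # current CUDA stream and returns immediately, instead of blocking the CPU.
    work.block_current_stream()

    # Kernels launched on the current stream from here on are ordered after the
    # all_reduce, so the Python thread stays free to issue more work.
    grad.div_(dist.get_world_size())
```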
Committed by: PyTorch MergeBot
Parent: 92f41ccc26
Commit: 1b3d69b59f
@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
             However, if timeout is set, it will block the CPU thread until the NCCL work is completed
             or timed out. If timeout, exception will be thrown.
           )")
+      .def(
+          "block_current_stream",
+          &::c10d::Work::blockCurrentStream,
+          py::call_guard<py::gil_scoped_release>(),
+          R"(
+            Blocks the currently active GPU stream on the operation to
+            complete. For GPU based collectives this is equivalent to
+            synchronize. For CPU initiated collectives such as with Gloo this
+            will block the CUDA stream until the operation is complete.
+
+            This returns immediately in all cases.
+
+            To check whether an operation was successful you should check the
+            Work object result asynchronously.
+          )")
       .def(
           "get_future_result",
           [](::c10d::Work& work)