mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This implements a new `wait_stream` API in Work that, for CPU-based backends such as Gloo, matches how `wait` works for ProcessGroupNCCL.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.

There was a previous attempt to make FSDPv2 use Work.wait, but given the extensive stream semantics FSDPv2 uses, it doesn't play nicely. https://github.com/pytorch/pytorch/pull/148780

This uses a "Baton" CUDA kernel which spinlocks on a pinned CPU tensor waiting for it to be set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k wait_stream
pytest test/distributed/test_c10d_nccl.py -v -k wait_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
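The baton idea from the commit message can be illustrated with a host-only analogy. This is a minimal sketch, not the code from the pull request: a worker thread stands in for the CUDA stream, a `std::atomic<bool>` stands in for the pinned CPU tensor, and the names `BatonStatus`, `spin_wait`, and `demo` are hypothetical. How the real kernel reacts to a timeout is not shown in this file, so the sketch simply reports it as a status.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <thread>

// Hypothetical host-only stand-in for the "baton". The real change spins
// inside a CUDA kernel on a pinned CPU tensor; here a thread spins on an atomic.
enum class BatonStatus { Released, TimedOut };

// Spin until `flag` becomes true or `timeout` elapses.
inline BatonStatus spin_wait(const std::atomic<bool>& flag,
                             std::chrono::milliseconds timeout) {
  const auto deadline = std::chrono::steady_clock::now() + timeout;
  while (!flag.load(std::memory_order_acquire)) {
    if (std::chrono::steady_clock::now() >= deadline) {
      return BatonStatus::TimedOut;
    }
    std::this_thread::yield();  // be polite to other threads while spinning
  }
  return BatonStatus::Released;
}

// The worker thread plays the blocked "stream"; the calling thread plays the
// CPU collective that sets the flag once its work (`work_time`) is finished.
inline BatonStatus demo(std::chrono::milliseconds work_time,
                        std::chrono::milliseconds timeout) {
  std::atomic<bool> done{false};
  BatonStatus status = BatonStatus::TimedOut;
  std::thread stream([&] { status = spin_wait(done, timeout); });
  std::this_thread::sleep_for(work_time);
  done.store(true, std::memory_order_release);
  stream.join();  // join orders the write to `status` before the read below
  return status;
}
```

Calling `demo` with a short `work_time` and a generous `timeout` yields `BatonStatus::Released`, while a `work_time` longer than the `timeout` yields `BatonStatus::TimedOut`.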
15 lines
440 B
C++
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

namespace c10d::cuda {

C10_DEFINE_REGISTRY(StreamBlockRegistry, StreamBlock, std::chrono::milliseconds)

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout) {
  auto baton = StreamBlockRegistry()->Create("CUDA", timeout);
  TORCH_CHECK(baton, "Failed to create StreamBlock");
  return baton;
}

} // namespace c10d::cuda
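
`block_stream` looks up an implementation by the key `"CUDA"` and fails loudly if none was registered. The pattern can be sketched without c10 as a string-keyed factory whose lookup returns a null pointer for unknown keys. Everything below is a hypothetical stand-in written for illustration (`Blocker`, `BlockerRegistry`, `FakeCudaBlocker`, `make_blocker`); it is not the c10 registry API, and the motivation (presumably to keep this translation unit free of a direct dependency on the CUDA implementation) is an assumption.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical interface standing in for StreamBlock.
struct Blocker {
  virtual ~Blocker() = default;
  virtual std::string backend() const = 0;
};

using Creator =
    std::function<std::unique_ptr<Blocker>(std::chrono::milliseconds)>;

// Hypothetical string-keyed factory; create() returns nullptr for unknown keys.
class BlockerRegistry {
 public:
  void add(const std::string& key, Creator creator) {
    creators_[key] = std::move(creator);
  }
  std::unique_ptr<Blocker> create(const std::string& key,
                                  std::chrono::milliseconds timeout) const {
    auto it = creators_.find(key);
    return it == creators_.end() ? nullptr : it->second(timeout);
  }

 private:
  std::map<std::string, Creator> creators_;
};

inline BlockerRegistry& registry() {
  static BlockerRegistry instance;  // one process-wide registry
  return instance;
}

// Hypothetical backend that would normally live in a separate library.
struct FakeCudaBlocker : Blocker {
  explicit FakeCudaBlocker(std::chrono::milliseconds t) : timeout(t) {}
  std::string backend() const override { return "CUDA"; }
  std::chrono::milliseconds timeout;
};

inline void register_fake_cuda_blocker() {
  registry().add("CUDA", [](std::chrono::milliseconds t) {
    return std::make_unique<FakeCudaBlocker>(t);
  });
}

// Same shape as block_stream: look up by key, then reject a null result.
inline std::unique_ptr<Blocker> make_blocker(const std::string& key,
                                             std::chrono::milliseconds timeout) {
  auto blocker = registry().create(key, timeout);
  if (!blocker) {
    throw std::runtime_error("Failed to create Blocker for key " + key);
  }
  return blocker;
}
```

After `register_fake_cuda_blocker()` has run, `make_blocker("CUDA", timeout)` returns a usable object, while an unregistered key throws instead of handing the caller a null pointer.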