From 1b3d69b59f92383633731aada8383ab88da3ed60 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Tue, 8 Jul 2025 23:55:41 +0000
Subject: [PATCH] Work: block_current_stream API (#156883)

This implements a new `block_current_stream` API on Work that gives CPU-based
backends such as Gloo the same stream-blocking semantics that `wait` already
provides for ProcessGroupNCCL. The goal is to support Gloo communication
overlap in FSDPv2/HSDP with minimal changes to FSDP.

A previous attempt (https://github.com/pytorch/pytorch/pull/148780) made
FSDPv2 call Work.wait directly, but it did not play nicely with FSDP's
extensive use of stream semantics.

This uses a "Baton" CUDA kernel that spin-waits on a pinned CPU tensor until
it is set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k block_current_stream
pytest test/distributed/test_c10d_nccl.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
---
 build_variables.bzl                           |  2 +
 test/distributed/test_c10d_gloo.py            | 15 ++++
 test/distributed/test_c10d_nccl.py            | 15 ++++
 test/distributed/test_c10d_pypg.py            | 34 +++++++++
 torch/_C/_distributed_c10d.pyi                |  3 +-
 .../distributed/c10d/ProcessGroupNCCL.hpp     |  4 ++
 torch/csrc/distributed/c10d/Work.cpp          | 10 +++
 torch/csrc/distributed/c10d/Work.hpp          |  7 ++
 .../distributed/c10d/cuda/StreamBlock.cpp     | 14 ++++
 .../csrc/distributed/c10d/cuda/StreamBlock.cu | 69 +++++++++++++++++++
 .../distributed/c10d/cuda/StreamBlock.cuh     | 29 ++++++++
 .../distributed/c10d/cuda/StreamBlock.hpp     | 38 ++++++++++
 torch/csrc/distributed/c10d/init.cpp          | 15 ++++
 13 files changed, 254 insertions(+), 1 deletion(-)
 create mode 100644 torch/csrc/distributed/c10d/cuda/StreamBlock.cpp
 create mode 100644 torch/csrc/distributed/c10d/cuda/StreamBlock.cu
 create mode 100644 torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
 create mode 100644 torch/csrc/distributed/c10d/cuda/StreamBlock.hpp

diff --git a/build_variables.bzl b/build_variables.bzl
index 51854e7c9000..99290d5318cd 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -518,6 +518,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
     "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
     "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
+    "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp",
     "torch/csrc/distributed/c10d/debug.cpp",
     "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
     "torch/csrc/distributed/c10d/logger.cpp",
@@ -734,6 +735,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
     "torch/csrc/distributed/c10d/cuda/utils.cpp",
+    "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py
index 96ad01b95b18..073c69f8dd8c 100644
--- a/test/distributed/test_c10d_gloo.py
+++ b/test/distributed/test_c10d_gloo.py
@@ -1562,6 +1562,21 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
         for i, tensor in enumerate(tensors):
             self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)

+    @skip_if_lt_x_gpu(2)
+    @requires_gloo()
+    @skipIfRocm
+    def test_block_current_stream_cuda(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_gloo(
+            store,
self.rank, self.world_size, self.opts() + ) + t = torch.zeros(10, device="cuda") + work = pg.allreduce(t) + work.block_current_stream() + torch.cuda.current_stream().synchronize() + + work.wait() + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c02e968e23fb..b2efcdbda03e 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase): ) dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx))) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_block_current_stream(self): + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + pg = self._create_process_group_nccl(store, self.opts(), device_id=device) + + t = torch.rand(10, device=device) + work = pg.allreduce(t) + work.block_current_stream() + + torch.cuda.current_stream().synchronize() + work.wait() + torch.cuda.synchronize() + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py index e516c8c94d32..6ccb81b116ca 100644 --- a/test/distributed/test_c10d_pypg.py +++ b/test/distributed/test_c10d_pypg.py @@ -1,5 +1,7 @@ # Owner(s): ["oncall: distributed"] +import time +import unittest import weakref import test_c10d_common @@ -10,6 +12,7 @@ import torch.nn as nn from torch._C._distributed_c10d import _create_work_from_future from torch.futures import Future from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import MultiThreadedTestCase from torch.testing._internal.common_utils import run_tests, TestCase @@ -197,6 +200,37 @@ class TestPyProcessGroup(TestCase): pg.abort() pg.shutdown() + @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") + def test_block_current_stream(self) -> None: + class BlockWork(dist._Work): + def __init__(self): + super().__init__() + self.future_ = torch.futures.Future() + + def get_future(self): + return self.future_ + + # nothing in queue so instantly resolves + event1 = torch.cuda.Event() + event1.record() + time.sleep(0.1) + self.assertTrue(event1.query()) + + work = BlockWork() + work.block_current_stream() + + # stream is blocked so doesn't resolve + event = torch.cuda.Event() + event.record() + time.sleep(0.1) + self.assertFalse(event.query()) + + # resolve the work + work.get_future().set_result(None) + + torch.cuda.current_stream().synchronize() + self.assertTrue(event.query()) + if __name__ == "__main__": run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index d145ed7ce653..e006c35bc90d 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -279,11 +279,12 @@ class Work: def is_success(self) -> bool: ... def exception(self) -> Any: ... def wait(self, timeout: timedelta = ...) -> bool: ... + def block_current_stream(self) -> None: ... def get_future(self) -> Future: ... def source_rank(self) -> int: ... def _source_rank(self) -> int: ... def result(self) -> list[Tensor]: ... - def synchronize(self): ... + def synchronize(self) -> None: ... def boxed(self) -> ScriptObject: ... 
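    # Usage sketch for `block_current_stream` (illustrative; `pg` stands in for
    # a process group and `t` for a CUDA tensor):
    #
    #   work = pg.allreduce(t)        # returns a Work object
    #   work.block_current_stream()   # returns immediately; the current CUDA
    #                                 # stream now waits on the collective
    #   out = t * 2                   # kernels queued on this stream run only
    #                                 # after the allreduce has completed
    #   work.wait()                   # surface errors / completion on the CPU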
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...

diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
index 104357fb1b38..4353b878df12 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -349,6 +349,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // or timed out. If timeout, exception will be thrown.
   bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

+  void blockCurrentStream() override {
+    synchronize();
+  }
+
   void abort() override;

   // Let current stream wait on the completion of the NCCL work
diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp
index 5e30e91ce05b..cdec9185ce53 100644
--- a/torch/csrc/distributed/c10d/Work.cpp
+++ b/torch/csrc/distributed/c10d/Work.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 #include
 #include

@@ -100,6 +101,15 @@ bool Work::wait(std::chrono::milliseconds timeout) {
   return true;
 }

+void Work::blockCurrentStream() {
+  // Block the CUDA stream indefinitely until the work is completed.
+  std::shared_ptr<c10d::cuda::StreamBlock> handle =
+      c10d::cuda::block_stream(std::chrono::milliseconds(0));
+
+  getFuture()->addCallback(
+      [handle](c10::ivalue::Future& future) { handle->abort(); });
+}
+
 void Work::abort() {
   TORCH_CHECK(false, "Work::abort not implemented.");
 }
diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp
index e9e785a9c643..3b743e36d2a0 100644
--- a/torch/csrc/distributed/c10d/Work.hpp
+++ b/torch/csrc/distributed/c10d/Work.hpp
@@ -110,6 +110,13 @@ class TORCH_API Work : public torch::CustomClassHolder {
   //
   virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);

+  // Blocks the current stream until the work is completed.
+  // This is equivalent to synchronize for CUDA tensors but works for both CPU
+  // tensors and CUDA tensors by using a spinlock CUDA kernel.
+  // This call returns immediately.
+  // If no stream is active it will throw an error.
+  virtual void blockCurrentStream();
+
   virtual void abort();

   // Returns a Future object that will be associated with the completion of
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cpp b/torch/csrc/distributed/c10d/cuda/StreamBlock.cpp
new file mode 100644
index 000000000000..5bc846243776
--- /dev/null
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cpp
@@ -0,0 +1,14 @@
+#include
+#include
+
+namespace c10d::cuda {
+
+C10_DEFINE_REGISTRY(StreamBlockRegistry, StreamBlock, std::chrono::milliseconds)
+
+std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout) {
+  auto baton = StreamBlockRegistry()->Create("CUDA", timeout);
+  TORCH_CHECK(baton, "Failed to create StreamBlock");
+  return baton;
+}
+
+} // namespace c10d::cuda
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu
new file mode 100644
index 000000000000..58533ece6af8
--- /dev/null
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu
@@ -0,0 +1,69 @@
+#include
+#include
+#include
+#include
+#include
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include
+#include
+#else
+#include
+#endif
+
+namespace c10d::cuda::detail {
+
+__device__ void nanosleep(int64_t ns) {
+  // This is a noop on pre-CUDA-7.0 and ROCm devices and effectively falls back
+  // to a spinlock. It can only sleep for a max of 1ms on CUDA devices.
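+  // __nanosleep is only a hint to the scheduler, so the actual sleep duration
+  // is approximate. Sleeping between polls keeps the busy-wait loop in
+  // kernel_barrier below from continuously re-reading the pinned host flag.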
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
+  __nanosleep(ns);
+#endif
+}
+
+__global__
+// set launch bounds to limit to 1 thread per block, 1 block per MP
+__launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
+  value[1] = StreamBlockStatus::RUNNING;
+
+  size_t start = c10d::symmetric_memory::global_timer_ns();
+  size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
+  while (true) {
+    // Atomically read the value
+    int current_value = atomicAdd(&value[0], 0);
+    // Check if the value is equal to the expected value
+    if (current_value == 1) {
+      value[1] = StreamBlockStatus::ABORTED;
+      return;
+    }
+
+    if (timeout_ms > 0) {
+      // Check if timeout has been reached
+      size_t now = c10d::symmetric_memory::global_timer_ns();
+      if ((now - start) > timeout_ns) {
+        value[1] = StreamBlockStatus::TIMED_OUT;
+        return;
+      }
+    }
+
+    // sleep for 1ms
+    nanosleep(1000000);
+  }
+}
+
+StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
+    : comm_{
+          // We need to pin the memory since we access the CPU memory directly
+          // from the GPU.
+          at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
+      },
+      timeout_{timeout} {
+  // grid size 1, block size 1, 0 bytes of shared memory
+  kernel_barrier<<<1, 1, 0>>>(
+      comm_.mutable_data_ptr<int32_t>(), timeout_.count());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)
+
+} // namespace c10d::cuda::detail
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
new file mode 100644
index 000000000000..f94f272d7eef
--- /dev/null
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
@@ -0,0 +1,29 @@
+#pragma once
+
+#include
+
+#include
+
+#include
+
+namespace c10d::cuda::detail {
+
+class StreamBlock : public ::c10d::cuda::StreamBlock {
+ public:
+  StreamBlock(std::chrono::milliseconds timeout);
+
+  void abort() override {
+    comm_[0] = 1;
+  }
+
+  StreamBlockStatus status() override {
+    return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
+  }
+
+ private:
+  // (abort flag, status)
+  const at::Tensor comm_;
+  const std::chrono::milliseconds timeout_;
+};
+
+} // namespace c10d::cuda::detail
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.hpp b/torch/csrc/distributed/c10d/cuda/StreamBlock.hpp
new file mode 100644
index 000000000000..391a82e58708
--- /dev/null
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.hpp
@@ -0,0 +1,38 @@
+#pragma once
+
+#include
+#include
+
+#include
+
+namespace c10d::cuda {
+
+enum StreamBlockStatus : int32_t {
+  UNKNOWN = 0,
+  RUNNING = 1,
+  TIMED_OUT = 2,
+  ABORTED = 3,
+};
+
+/*
+StreamBlock implements a baton that blocks the active CUDA stream
+until it is aborted by the main process.
+*/
+class TORCH_API StreamBlock {
+ public:
+  virtual ~StreamBlock() = default;
+  virtual void abort() = 0;
+  virtual StreamBlockStatus status() = 0;
+};
+
+std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);
+
+// Declare a registry so we can call the CUDA StreamBlock API from CPU-only code
+// (i.e. ProcessGroup/Work objects in libtorch_cpu).
+// The implementation is defined in StreamBlock.cu.
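+//
+// Usage sketch (illustrative, mirroring Work::blockCurrentStream): CPU-only
+// code obtains a baton from the registry-backed block_stream() factory and
+// aborts it once the CPU-side work has finished.
+//
+//   // Launches the blocking kernel on the stream; returns immediately on the
+//   // CPU. A timeout of 0 means "block until aborted".
+//   std::unique_ptr<StreamBlock> handle =
+//       c10d::cuda::block_stream(std::chrono::milliseconds(0));
+//   // ... CPU-side work (e.g. a Gloo collective) completes ...
+//   handle->abort(); // sets the pinned flag; the kernel exits and the
+//                    // blocked stream resumes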
+TORCH_DECLARE_REGISTRY(
+    StreamBlockRegistry,
+    StreamBlock,
+    std::chrono::milliseconds);
+
+} // namespace c10d::cuda
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 6f5ae85ca9a6..617499e858a1 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
            However, if timeout is set, it will block the CPU thread until the NCCL work is
            completed or timed out. If timeout, exception will be thrown.
          )")
+      .def(
+          "block_current_stream",
+          &::c10d::Work::blockCurrentStream,
+          py::call_guard<py::gil_scoped_release>(),
+          R"(
+            Blocks the currently active GPU stream until the operation
+            completes. For GPU-based collectives this is equivalent to
+            synchronize. For CPU-initiated collectives such as those run with
+            Gloo, this blocks the CUDA stream until the operation is complete.
+
+            This call returns immediately in all cases.
+
+            To check whether an operation was successful, check the Work
+            object's result asynchronously.
+          )")
       .def(
           "get_future_result",
           [](::c10d::Work& work)