Work: block_current_stream API (#156883)

This implements a new `block_current_stream` API on `Work` that gives CPU-based backends such as Gloo the same stream-blocking behavior that `wait` provides for ProcessGroupNCCL.

The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP.

There was a previous attempt to make FSDPv2 use `Work.wait`, but given the extensive stream semantics FSDP uses, that approach doesn't play nicely: https://github.com/pytorch/pytorch/pull/148780

This uses a "baton" CUDA kernel that spin-waits on a pinned CPU tensor until it is set, blocking whichever stream it was launched on.
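A rough usage sketch of the intended pattern (assuming an already-initialized Gloo process group and an available CUDA device; this mirrors the new tests in this commit):

```
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("gloo", ...) has already run and CUDA is
# available.
t = torch.ones(10, device="cuda")
work = dist.all_reduce(t, async_op=True)  # Gloo runs this on a CPU thread

# Enqueues the baton kernel on the current CUDA stream and returns
# immediately; the CPU thread is never blocked.
work.block_current_stream()

# Kernels launched on this stream from here on are ordered after the
# allreduce completes.
t.mul_(2)
```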

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k block_current_stream
pytest test/distributed/test_c10d_nccl.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
Tristan Rice
2025-07-08 23:55:41 +00:00
committed by PyTorch MergeBot
parent 92f41ccc26
commit 1b3d69b59f
13 changed files with 254 additions and 1 deletion

View File

@@ -518,6 +518,7 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
    "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
    "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
    "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp",
    "torch/csrc/distributed/c10d/debug.cpp",
    "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
    "torch/csrc/distributed/c10d/logger.cpp",
@@ -734,6 +735,7 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/UCCUtils.cpp",
    "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
    "torch/csrc/distributed/c10d/cuda/utils.cpp",
    "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
    "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",

View File

@@ -1562,6 +1562,21 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    @skipIfRocm
    def test_block_current_stream_cuda(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t = torch.zeros(10, device="cuda")
        work = pg.allreduce(t)
        work.block_current_stream()

        torch.cuda.current_stream().synchronize()
        work.wait()


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase

View File

@@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        )
        dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx)))

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_block_current_stream(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)

        t = torch.rand(10, device=device)
        work = pg.allreduce(t)
        work.block_current_stream()

        torch.cuda.current_stream().synchronize()
        work.wait()
        torch.cuda.synchronize()


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase

View File

@@ -1,5 +1,7 @@
# Owner(s): ["oncall: distributed"]

import time
import unittest
import weakref

import test_c10d_common
@@ -10,6 +12,7 @@ import torch.nn as nn
from torch._C._distributed_c10d import _create_work_from_future
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -197,6 +200,37 @@ class TestPyProcessGroup(TestCase):
        pg.abort()
        pg.shutdown()

    @unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
    def test_block_current_stream(self) -> None:
        class BlockWork(dist._Work):
            def __init__(self):
                super().__init__()
                self.future_ = torch.futures.Future()

            def get_future(self):
                return self.future_

        # nothing in queue so instantly resolves
        event1 = torch.cuda.Event()
        event1.record()
        time.sleep(0.1)
        self.assertTrue(event1.query())

        work = BlockWork()
        work.block_current_stream()

        # stream is blocked so doesn't resolve
        event = torch.cuda.Event()
        event.record()
        time.sleep(0.1)
        self.assertFalse(event.query())

        # resolve the work
        work.get_future().set_result(None)
        torch.cuda.current_stream().synchronize()
        self.assertTrue(event.query())


if __name__ == "__main__":
    run_tests()

View File

@@ -279,11 +279,12 @@ class Work:
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def block_current_stream(self) -> None: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> list[Tensor]: ...
    def synchronize(self): ...
    def synchronize(self) -> None: ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...

View File

@@ -349,6 +349,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // or timed out. If it times out, an exception will be thrown.
  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

  void blockCurrentStream() override {
    synchronize();
  }

  void abort() override;

  // Let current stream wait on the completion of the NCCL work

View File

@@ -1,5 +1,6 @@
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>
@@ -100,6 +101,15 @@ bool Work::wait(std::chrono::milliseconds timeout) {
  return true;
}

void Work::blockCurrentStream() {
  // Block the current CUDA stream indefinitely until the work is completed.
  std::shared_ptr<c10d::cuda::StreamBlock> handle =
      c10d::cuda::block_stream(std::chrono::milliseconds(0));

  getFuture()->addCallback(
      [handle](c10::ivalue::Future& future) { handle->abort(); });
}

void Work::abort() {
  TORCH_CHECK(false, "Work::abort not implemented.");
}

View File

@@ -110,6 +110,13 @@ class TORCH_API Work : public torch::CustomClassHolder {
  //
  virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);

  // Blocks the current stream until the work is completed.
  // This is equivalent to synchronize() for CUDA tensors but also works for
  // CPU tensors via a spin-waiting CUDA kernel.
  // This returns immediately.
  // If no stream is active it will throw an error.
  virtual void blockCurrentStream();

  virtual void abort();

  // Returns a Future object that will be associated with the completion of

View File

@@ -0,0 +1,14 @@
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

namespace c10d::cuda {

C10_DEFINE_REGISTRY(StreamBlockRegistry, StreamBlock, std::chrono::milliseconds)

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout) {
  auto baton = StreamBlockRegistry()->Create("CUDA", timeout);
  TORCH_CHECK(baton, "Failed to create StreamBlock");
  return baton;
}

} // namespace c10d::cuda

View File

@@ -0,0 +1,69 @@
#include <ATen/native/TensorFactories.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace c10d::cuda::detail {

__device__ void nanosleep(int64_t ns) {
  // This is a noop on pre-sm_70 (pre-Volta) and ROCm devices and effectively
  // falls back to a spinlock. __nanosleep can only sleep for a max of ~1ms on
  // CUDA devices.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
  __nanosleep(ns);
#endif
}

__global__
// set launch bounds to limit to 1 thread per block, 1 block per MP
__launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
  value[1] = StreamBlockStatus::RUNNING;
  size_t start = c10d::symmetric_memory::global_timer_ns();
  size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
  while (true) {
    // Atomically read the abort flag
    int current_value = atomicAdd(&value[0], 0);
    // Check if the value is equal to the expected value
    if (current_value == 1) {
      value[1] = StreamBlockStatus::ABORTED;
      return;
    }
    if (timeout_ms > 0) {
      // Check if timeout has been reached
      size_t now = c10d::symmetric_memory::global_timer_ns();
      if ((now - start) > timeout_ns) {
        value[1] = StreamBlockStatus::TIMED_OUT;
        return;
      }
    }
    // sleep for 1ms
    nanosleep(1000000);
  }
}

StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
    : comm_{
          // We need to pin the memory since we access the CPU memory directly
          // from the GPU.
          at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
      },
      timeout_{timeout} {
  // grid size 1, block size 1, 0 bytes of shared memory; launch on the
  // current stream so that stream is the one being blocked
  kernel_barrier<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
      comm_.mutable_data_ptr<int32_t>(), timeout_.count());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)

} // namespace c10d::cuda::detail
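
For intuition, here is a CPU-side Python analogue of the protocol the kernel implements (illustrative only; the constants mirror `StreamBlockStatus` above):

```
import time

RUNNING, TIMED_OUT, ABORTED = 1, 2, 3

def baton_wait(comm: list, timeout_ms: int) -> None:
    # comm[0] is the abort flag written by the CPU, comm[1] is the status.
    comm[1] = RUNNING
    start = time.monotonic_ns()
    while True:
        if comm[0] == 1:  # abort() was called; the stream can make progress
            comm[1] = ABORTED
            return
        if timeout_ms > 0 and time.monotonic_ns() - start > timeout_ms * 1_000_000:
            comm[1] = TIMED_OUT
            return
        time.sleep(0.001)  # roughly what __nanosleep(1000000) does per iteration
```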

View File

@@ -0,0 +1,29 @@
#pragma once

#include <chrono>

#include <ATen/core/Tensor.h>

#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

namespace c10d::cuda::detail {

class StreamBlock : public ::c10d::cuda::StreamBlock {
 public:
  StreamBlock(std::chrono::milliseconds timeout);

  void abort() override {
    comm_[0] = 1;
  }

  StreamBlockStatus status() override {
    return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
  }

 private:
  // comm_[0] is the abort flag, comm_[1] is the StreamBlockStatus
  const at::Tensor comm_;
  const std::chrono::milliseconds timeout_;
};

} // namespace c10d::cuda::detail

View File

@@ -0,0 +1,38 @@
#pragma once

#include <chrono>
#include <memory>

#include <c10/util/Registry.h>

namespace c10d::cuda {

enum StreamBlockStatus : int32_t {
  UNKNOWN = 0,
  RUNNING = 1,
  TIMED_OUT = 2,
  ABORTED = 3,
};

/*
StreamBlock implements a baton that blocks the active CUDA stream
until aborted by the main process.
*/
class TORCH_API StreamBlock {
 public:
  virtual ~StreamBlock() = default;

  virtual void abort() = 0;
  virtual StreamBlockStatus status() = 0;
};

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);

// Declare a registry so we can call the CUDA StreamBlock API from CPU-only
// code (i.e. ProcessGroup/Work objects in libtorch_cpu).
// The implementation is defined in StreamBlock.cu.
TORCH_DECLARE_REGISTRY(
    StreamBlockRegistry,
    StreamBlock,
    std::chrono::milliseconds);

} // namespace c10d::cuda
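
A rough Python analogue of what the registry buys here, namely constructing the CUDA implementation from CPU-only code without a link-time dependency (hypothetical names, sketching the C10 registry indirection):

```
# Hypothetical sketch of the registry indirection.
_registry = {}

def register(key):
    def deco(cls):
        _registry[key] = cls
        return cls
    return deco

def block_stream(timeout_ms):
    # Mirrors c10d::cuda::block_stream: look up whatever implementation the
    # CUDA library registered when it was loaded.
    cls = _registry.get("CUDA")
    assert cls is not None, "Failed to create StreamBlock"
    return cls(timeout_ms)

@register("CUDA")  # done by C10_REGISTER_CLASS in StreamBlock.cu
class CudaStreamBlock:
    def __init__(self, timeout_ms):
        self.timeout_ms = timeout_ms
```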

View File

@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
            However, if a timeout is set, it will block the CPU thread until the NCCL work is completed
            or timed out. If it times out, an exception will be thrown.
        )")
        .def(
            "block_current_stream",
            &::c10d::Work::blockCurrentStream,
            py::call_guard<py::gil_scoped_release>(),
            R"(
            Blocks the currently active GPU stream until the operation
            completes. For GPU-based collectives this is equivalent to
            synchronize. For CPU-initiated collectives such as those using
            Gloo, this blocks the CUDA stream until the operation is complete.

            This returns immediately in all cases.

            To check whether an operation was successful, check the Work
            object's result asynchronously.
        )")
        .def(
            "get_future_result",
            [](::c10d::Work& work)
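
As that docstring suggests, success should be observed through the Work's future rather than by blocking the CPU. A short sketch (assuming `pg` and `t` as in the tests above):

```
work = pg.allreduce(t)
work.block_current_stream()  # returns immediately

fut = work.get_future()  # resolves once the collective finishes
fut.add_done_callback(lambda f: print("allreduce finished:", f.value()))
```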