Work: block_current_stream API (#156883)
This implements a new `block_current_stream` API on Work that gives CPU-based backends such as Gloo the same stream-blocking behavior that `wait` already provides for ProcessGroupNCCL. The goal is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP. A previous attempt made FSDPv2 call Work.wait directly, but given the extensive stream semantics involved it didn't play nicely: https://github.com/pytorch/pytorch/pull/148780

The implementation uses a "baton" CUDA kernel that spinlocks on a pinned CPU tensor, waiting for it to be set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k block_current_stream
pytest test/distributed/test_c10d_nccl.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: 92f41ccc26
Commit: 1b3d69b59f
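For orientation, here is a minimal sketch of how the new API is intended to be used for communication/compute overlap, modeled on the Gloo test added below (the process-group setup, tensor size, and follow-up compute are illustrative assumptions, not part of this change):

```
import torch
import torch.distributed as dist

# Assumes a Gloo process group and a CUDA device have already been set up,
# e.g. dist.init_process_group("gloo", ...) with one GPU per rank.
t = torch.zeros(1024, device="cuda")
work = dist.all_reduce(t, async_op=True)

# Returns immediately on the CPU; the current CUDA stream is blocked by the
# baton kernel until the Gloo all-reduce has completed.
work.block_current_stream()

# Anything queued on the same stream afterwards is ordered after the
# collective, so the CPU thread stays free to keep enqueueing work.
out = t * 2

# Optionally surface collective errors on the CPU side.
work.wait()
```

This mirrors the stated FSDPv2/HSDP use case: the collective result is consumed in stream order without forcing the CPU thread to block.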
build_variables.bzl

@@ -518,6 +518,7 @@ libtorch_distributed_base_sources = [
    "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
    "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
    "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
    "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp",
    "torch/csrc/distributed/c10d/debug.cpp",
    "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
    "torch/csrc/distributed/c10d/logger.cpp",

@@ -734,6 +735,7 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/UCCUtils.cpp",
    "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
    "torch/csrc/distributed/c10d/cuda/utils.cpp",
    "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
    "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
    "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
test/distributed/test_c10d_gloo.py

@@ -1562,6 +1562,21 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    @skipIfRocm
    def test_block_current_stream_cuda(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        t = torch.zeros(10, device="cuda")
        work = pg.allreduce(t)
        work.block_current_stream()
        torch.cuda.current_stream().synchronize()

        work.wait()


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
test/distributed/test_c10d_nccl.py

@@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
        )
        dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx)))

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_block_current_stream(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)

        t = torch.rand(10, device=device)
        work = pg.allreduce(t)
        work.block_current_stream()

        torch.cuda.current_stream().synchronize()
        work.wait()
        torch.cuda.synchronize()


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
test/distributed/test_c10d_pypg.py

@@ -1,5 +1,7 @@
# Owner(s): ["oncall: distributed"]

import time
import unittest
import weakref

import test_c10d_common

@@ -10,6 +12,7 @@ import torch.nn as nn
from torch._C._distributed_c10d import _create_work_from_future
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import MultiThreadedTestCase
from torch.testing._internal.common_utils import run_tests, TestCase

@@ -197,6 +200,37 @@ class TestPyProcessGroup(TestCase):
        pg.abort()
        pg.shutdown()

    @unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
    def test_block_current_stream(self) -> None:
        class BlockWork(dist._Work):
            def __init__(self):
                super().__init__()
                self.future_ = torch.futures.Future()

            def get_future(self):
                return self.future_

        # nothing in queue so instantly resolves
        event1 = torch.cuda.Event()
        event1.record()
        time.sleep(0.1)
        self.assertTrue(event1.query())

        work = BlockWork()
        work.block_current_stream()

        # stream is blocked so doesn't resolve
        event = torch.cuda.Event()
        event.record()
        time.sleep(0.1)
        self.assertFalse(event.query())

        # resolve the work
        work.get_future().set_result(None)

        torch.cuda.current_stream().synchronize()
        self.assertTrue(event.query())


if __name__ == "__main__":
    run_tests()
torch/_C/_distributed_c10d.pyi

@@ -279,11 +279,12 @@ class Work:
    def is_success(self) -> bool: ...
    def exception(self) -> Any: ...
    def wait(self, timeout: timedelta = ...) -> bool: ...
    def block_current_stream(self) -> None: ...
    def get_future(self) -> Future: ...
    def source_rank(self) -> int: ...
    def _source_rank(self) -> int: ...
    def result(self) -> list[Tensor]: ...
    def synchronize(self): ...
    def synchronize(self) -> None: ...
    def boxed(self) -> ScriptObject: ...
    @staticmethod
    def unbox(obj: ScriptObject) -> Work: ...
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

@@ -349,6 +349,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
  // or timed out. If timeout, exception will be thrown.
  bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

  void blockCurrentStream() override {
    synchronize();
  }

  void abort() override;

  // Let current stream wait on the completion of the NCCL work
torch/csrc/distributed/c10d/Work.cpp

@@ -1,5 +1,6 @@
#include <ATen/ThreadLocalState.h>
#include <distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

#include <torch/csrc/distributed/c10d/Work.hpp>
#include <utility>

@@ -100,6 +101,15 @@ bool Work::wait(std::chrono::milliseconds timeout) {
  return true;
}

void Work::blockCurrentStream() {
  // block cuda stream indefinitely until work is completed.
  std::shared_ptr<c10d::cuda::StreamBlock> handle =
      c10d::cuda::block_stream(std::chrono::milliseconds(0));

  getFuture()->addCallback(
      [handle](c10::ivalue::Future& future) { handle->abort(); });
}

void Work::abort() {
  TORCH_CHECK(false, "Work::abort not implemented.");
}
torch/csrc/distributed/c10d/Work.hpp

@@ -110,6 +110,13 @@ class TORCH_API Work : public torch::CustomClassHolder {
  //
  virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);

  // Blocks the current stream until the work is completed.
  // This is equivalent to synchronize for CUDA tensors but works for both CPU
  // tensors and CUDA tensors by using a spinlock CUDA kernel.
  // This will immediately return.
  // If no stream is active it will throw an error.
  virtual void blockCurrentStream();

  virtual void abort();

  // Returns a Future object that will be associated with the completion of
torch/csrc/distributed/c10d/cuda/StreamBlock.cpp (new file, 14 lines)

@@ -0,0 +1,14 @@
#include <c10/util/Exception.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

namespace c10d::cuda {

C10_DEFINE_REGISTRY(StreamBlockRegistry, StreamBlock, std::chrono::milliseconds)

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout) {
  auto baton = StreamBlockRegistry()->Create("CUDA", timeout);
  TORCH_CHECK(baton, "Failed to create StreamBlock");
  return baton;
}

} // namespace c10d::cuda
torch/csrc/distributed/c10d/cuda/StreamBlock.cu (new file, 69 lines)

@@ -0,0 +1,69 @@
#include <ATen/native/TensorFactories.h>
#include <c10/cuda/CUDAException.h>
#include <cuda_runtime.h>
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace c10d::cuda::detail {

__device__ void nanosleep(int64_t ns) {
  // This is a noop on pre-CUDA-7.0 and ROCm devices and effectively falls back
  // to a spinlock. This only can sleep for a max of 1ms on CUDA devices.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
  __nanosleep(ns);
#endif
}

__global__
// set launch bounds to limit to 1 thread per block, 1 block per MP
__launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
  value[1] = StreamBlockStatus::RUNNING;

  size_t start = c10d::symmetric_memory::global_timer_ns();
  size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
  while (true) {
    // Atomically read the value
    int current_value = atomicAdd(&value[0], 0);
    // Check if the value is equal to the expected value
    if (current_value == 1) {
      value[1] = StreamBlockStatus::ABORTED;
      return;
    }

    if (timeout_ms > 0) {
      // Check if timeout has been reached
      size_t now = c10d::symmetric_memory::global_timer_ns();
      if ((now - start) > timeout_ns) {
        value[1] = StreamBlockStatus::TIMED_OUT;
        return;
      }
    }

    // sleep for 1ms
    nanosleep(1000000);
  }
}

StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
    : comm_{
          // We need to pin the memory since we access the CPU memory directly
          // from the GPU.
          at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
      },
      timeout_{timeout} {
  // grid size 1, block size 1, 0 bytes of shared memory
  kernel_barrier<<<1, 1, 0>>>(
      comm_.mutable_data_ptr<int32_t>(), timeout_.count());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)

} // namespace c10d::cuda::detail
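To make the kernel's protocol concrete, here is a hypothetical host-side view of the two-element pinned buffer it polls (a sketch only; the real buffer is the `comm_` tensor owned by the StreamBlock class below):

```
import torch

# comm[0] is the abort flag the CPU sets; comm[1] is the status the kernel
# writes (RUNNING / TIMED_OUT / ABORTED, per the StreamBlock.hpp enum below).
# Pinned memory is required because the spinning GPU kernel reads host memory
# directly.
comm = torch.zeros(2, dtype=torch.int32).pin_memory()

# ... kernel_barrier is launched on the current stream and spins on comm[0] ...

# When the Work's future completes, the CPU flips the abort flag, which lets
# the blocked stream resume (this is what StreamBlock::abort() does).
comm[0] = 1
```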
torch/csrc/distributed/c10d/cuda/StreamBlock.cuh (new file, 29 lines)

@@ -0,0 +1,29 @@
#pragma once

#include <chrono>

#include <ATen/core/Tensor.h>

#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>

namespace c10d::cuda::detail {

class StreamBlock : public ::c10d::cuda::StreamBlock {
 public:
  StreamBlock(std::chrono::milliseconds timeout);

  void abort() override {
    comm_[0] = 1;
  }

  StreamBlockStatus status() override {
    return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
  }

 private:
  // (abort, cycles)
  const at::Tensor comm_;
  const std::chrono::milliseconds timeout_;
};

} // namespace c10d::cuda::detail
torch/csrc/distributed/c10d/cuda/StreamBlock.hpp (new file, 38 lines)

@@ -0,0 +1,38 @@
#pragma once

#include <chrono>
#include <memory>

#include <c10/util/Registry.h>

namespace c10d::cuda {

enum StreamBlockStatus : int32_t {
  UNKNOWN = 0,
  RUNNING = 1,
  TIMED_OUT = 2,
  ABORTED = 3,
};

/*
StreamBlock implements a baton that blocks the active CUDA stream
until aborted by the main process.
*/
class TORCH_API StreamBlock {
 public:
  virtual ~StreamBlock() = default;
  virtual void abort() = 0;
  virtual StreamBlockStatus status() = 0;
};

std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);

// Declare a registry so we can call the CUDA StreamBlock API from CPU-only code
// (i.e. ProcessGroup/Work objects in libtorch_cpu).
// The implementation is defined in StreamBlock.cu.
TORCH_DECLARE_REGISTRY(
    StreamBlockRegistry,
    StreamBlock,
    std::chrono::milliseconds);

} // namespace c10d::cuda
torch/csrc/distributed/c10d/init.cpp

@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
            However, if timeout is set, it will block the CPU thread until the NCCL work is completed
            or timed out. If timeout, exception will be thrown.
          )")
      .def(
          "block_current_stream",
          &::c10d::Work::blockCurrentStream,
          py::call_guard<py::gil_scoped_release>(),
          R"(
            Blocks the currently active GPU stream on the operation to
            complete. For GPU based collectives this is equivalent to
            synchronize. For CPU initiated collectives such as with Gloo this
            will block the CUDA stream until the operation is complete.

            This returns immediately in all cases.

            To check whether an operation was successful you should check the
            Work object result asynchronously.
          )")
      .def(
          "get_future_result",
          [](::c10d::Work& work)
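As the docstring above notes, `block_current_stream` returns immediately and does not itself report failures; here is a small sketch of checking the outcome asynchronously through the Work's future (the callback and tensor are illustrative assumptions):

```
import torch
import torch.distributed as dist

t = torch.ones(8, device="cuda")
work = dist.all_reduce(t, async_op=True)
work.block_current_stream()  # stream-ordered; no error reporting here

def _on_done(fut):
    # The future is already completed inside the callback: wait() returns the
    # result, or re-raises the error recorded by the backend if the
    # collective failed.
    try:
        fut.wait()
    except Exception as exc:
        print(f"collective failed: {exc}")

work.get_future().then(_on_done)
```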