Work: block_current_stream API (#156883)
This implements a new `wait_stream` API in Work that matches how `wait` works for ProcessGroupNCCL for CPU based backends such as Gloo. The idea is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP. There was a previous attempt to make FSDPv2 use Work.wait but given the extensive stream semantics used it doesn't play nicely. https://github.com/pytorch/pytorch/pull/148780 This uses a "Baton" CUDA kernel which spinlocks on a pinned CPU tensor waiting for it to be set. Test plan: ``` pytest test/distributed/test_c10d_gloo.py -v -k wait_stream pytest test/distributed/test_c10d_nccl.py -v -k wait_stream ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883 Approved by: https://github.com/kwen2501, https://github.com/fduwjj
Commit 1b3d69b59f (parent 92f41ccc26), committed by PyTorch MergeBot.
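As a rough illustration (not part of the commit), this is how the new API is intended to be used from Python, following the tests added below. It assumes the default process group has already been initialized with Gloo and that this rank has a CUDA device; the setup lines are hypothetical.

```python
import torch
import torch.distributed as dist

# Hypothetical setup: dist.init_process_group("gloo") has already run and a
# CUDA device is available on this rank (mirrors the tests in this commit).
t = torch.zeros(10, device="cuda")

work = dist.all_reduce(t, async_op=True)
work.block_current_stream()  # returns immediately; blocks the CUDA stream, not the CPU

# Kernels queued from here on only execute once the Gloo allreduce finishes,
# so GPU work can be overlapped without a CPU-side wait().
t.mul_(2)

torch.cuda.current_stream().synchronize()
work.wait()
```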
@@ -518,6 +518,7 @@ libtorch_distributed_base_sources = [
     "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp",
     "torch/csrc/distributed/c10d/control_plane/Handlers.cpp",
     "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp",
+    "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp",
     "torch/csrc/distributed/c10d/debug.cpp",
     "torch/csrc/distributed/c10d/default_comm_hooks.cpp",
     "torch/csrc/distributed/c10d/logger.cpp",
@@ -734,6 +735,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
     "torch/csrc/distributed/c10d/cuda/utils.cpp",
+    "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
     "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
@@ -1562,6 +1562,21 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
         for i, tensor in enumerate(tensors):
             self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)
 
+    @skip_if_lt_x_gpu(2)
+    @requires_gloo()
+    @skipIfRocm
+    def test_block_current_stream_cuda(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_gloo(
+            store, self.rank, self.world_size, self.opts()
+        )
+        t = torch.zeros(10, device="cuda")
+        work = pg.allreduce(t)
+        work.block_current_stream()
+        torch.cuda.current_stream().synchronize()
+
+        work.wait()
+
 
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
@@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         )
         dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx)))
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_block_current_stream(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank}")
+        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
+
+        t = torch.rand(10, device=device)
+        work = pg.allreduce(t)
+        work.block_current_stream()
+
+        torch.cuda.current_stream().synchronize()
+        work.wait()
+        torch.cuda.synchronize()
+
 
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
@@ -1,5 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
+import time
+import unittest
 import weakref
 
 import test_c10d_common
@@ -10,6 +12,7 @@ import torch.nn as nn
 from torch._C._distributed_c10d import _create_work_from_future
 from torch.futures import Future
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import MultiThreadedTestCase
 from torch.testing._internal.common_utils import run_tests, TestCase
 
@@ -197,6 +200,37 @@ class TestPyProcessGroup(TestCase):
         pg.abort()
         pg.shutdown()
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
+    def test_block_current_stream(self) -> None:
+        class BlockWork(dist._Work):
+            def __init__(self):
+                super().__init__()
+                self.future_ = torch.futures.Future()
+
+            def get_future(self):
+                return self.future_
+
+        # nothing in queue so instantly resolves
+        event1 = torch.cuda.Event()
+        event1.record()
+        time.sleep(0.1)
+        self.assertTrue(event1.query())
+
+        work = BlockWork()
+        work.block_current_stream()
+
+        # stream is blocked so doesn't resolve
+        event = torch.cuda.Event()
+        event.record()
+        time.sleep(0.1)
+        self.assertFalse(event.query())
+
+        # resolve the work
+        work.get_future().set_result(None)
+
+        torch.cuda.current_stream().synchronize()
+        self.assertTrue(event.query())
+
 
 if __name__ == "__main__":
     run_tests()
@@ -279,11 +279,12 @@ class Work:
     def is_success(self) -> bool: ...
     def exception(self) -> Any: ...
     def wait(self, timeout: timedelta = ...) -> bool: ...
+    def block_current_stream(self) -> None: ...
    def get_future(self) -> Future: ...
     def source_rank(self) -> int: ...
     def _source_rank(self) -> int: ...
     def result(self) -> list[Tensor]: ...
-    def synchronize(self): ...
+    def synchronize(self) -> None: ...
     def boxed(self) -> ScriptObject: ...
     @staticmethod
     def unbox(obj: ScriptObject) -> Work: ...
@@ -349,6 +349,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   // or timed out. If timeout, exception will be thrown.
   bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
 
+  void blockCurrentStream() override {
+    synchronize();
+  }
+
   void abort() override;
 
   // Let current stream wait on the completion of the NCCL work
@@ -1,5 +1,6 @@
 #include <ATen/ThreadLocalState.h>
 #include <distributed/c10d/ProcessGroup.hpp>
+#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
 
 #include <torch/csrc/distributed/c10d/Work.hpp>
 #include <utility>
@@ -100,6 +101,15 @@ bool Work::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
+void Work::blockCurrentStream() {
+  // block cuda stream indefinitely until work is completed.
+  std::shared_ptr<c10d::cuda::StreamBlock> handle =
+      c10d::cuda::block_stream(std::chrono::milliseconds(0));
+
+  getFuture()->addCallback(
+      [handle](c10::ivalue::Future& future) { handle->abort(); });
+}
+
 void Work::abort() {
   TORCH_CHECK(false, "Work::abort not implemented.");
 }
@@ -110,6 +110,13 @@ class TORCH_API Work : public torch::CustomClassHolder {
   //
   virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);
 
+  // Blocks the current stream until the work is completed.
+  // This is equivalent to synchronize for CUDA tensors but works for both CPU
+  // tensors and CUDA tensors by using a spinlock CUDA kernel.
+  // This will immediately return.
+  // If no stream is active it will throw an error.
+  virtual void blockCurrentStream();
+
   virtual void abort();
 
   // Returns a Future object that will be associated with the completion of
torch/csrc/distributed/c10d/cuda/StreamBlock.cpp (new file, 14 lines)
@@ -0,0 +1,14 @@
+#include <c10/util/Exception.h>
+#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
+
+namespace c10d::cuda {
+
+C10_DEFINE_REGISTRY(StreamBlockRegistry, StreamBlock, std::chrono::milliseconds)
+
+std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout) {
+  auto baton = StreamBlockRegistry()->Create("CUDA", timeout);
+  TORCH_CHECK(baton, "Failed to create StreamBlock");
+  return baton;
+}
+
+} // namespace c10d::cuda
torch/csrc/distributed/c10d/cuda/StreamBlock.cu (new file, 69 lines)
@@ -0,0 +1,69 @@
+#include <ATen/native/TensorFactories.h>
+#include <c10/cuda/CUDAException.h>
+#include <cuda_runtime.h>
+#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h>
+#include <torch/csrc/distributed/c10d/cuda/StreamBlock.cuh>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#include <ATen/ops/empty.h>
+#endif
+
+namespace c10d::cuda::detail {
+
+__device__ void nanosleep(int64_t ns) {
+  // This is a noop on pre-CUDA-7.0 and ROCm devices and effectively falls back
+  // to a spinlock. This only can sleep for a max of 1ms on CUDA devices.
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
+  __nanosleep(ns);
+#endif
+}
+
+__global__
+// set launch bounds to limit to 1 thread per block, 1 block per MP
+__launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
+  value[1] = StreamBlockStatus::RUNNING;
+
+  size_t start = c10d::symmetric_memory::global_timer_ns();
+  size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
+  while (true) {
+    // Atomically read the value
+    int current_value = atomicAdd(&value[0], 0);
+    // Check if the value is equal to the expected value
+    if (current_value == 1) {
+      value[1] = StreamBlockStatus::ABORTED;
+      return;
+    }
+
+    if (timeout_ms > 0) {
+      // Check if timeout has been reached
+      size_t now = c10d::symmetric_memory::global_timer_ns();
+      if ((now - start) > timeout_ns) {
+        value[1] = StreamBlockStatus::TIMED_OUT;
+        return;
+      }
+    }
+
+    // sleep for 1ms
+    nanosleep(1000000);
+  }
+}
+
+StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
+    : comm_{
+          // We need to pin the memory since we access the CPU memory directly
+          // from the GPU.
+          at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
+      },
+      timeout_{timeout} {
+  // grid size 1, block size 1, 0 bytes of shared memory
+  kernel_barrier<<<1, 1, 0>>>(
+      comm_.mutable_data_ptr<int32_t>(), timeout_.count());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)
+
+} // namespace c10d::cuda::detail
torch/csrc/distributed/c10d/cuda/StreamBlock.cuh (new file, 29 lines)
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <chrono>
+
+#include <ATen/core/Tensor.h>
+
+#include <torch/csrc/distributed/c10d/cuda/StreamBlock.hpp>
+
+namespace c10d::cuda::detail {
+
+class StreamBlock : public ::c10d::cuda::StreamBlock {
+ public:
+  StreamBlock(std::chrono::milliseconds timeout);
+
+  void abort() override {
+    comm_[0] = 1;
+  }
+
+  StreamBlockStatus status() override {
+    return static_cast<StreamBlockStatus>(comm_[1].item<int32_t>());
+  }
+
+ private:
+  // (abort, cycles)
+  const at::Tensor comm_;
+  const std::chrono::milliseconds timeout_;
+};
+
+} // namespace c10d::cuda::detail
torch/csrc/distributed/c10d/cuda/StreamBlock.hpp (new file, 38 lines)
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <chrono>
+#include <memory>
+
+#include <c10/util/Registry.h>
+
+namespace c10d::cuda {
+
+enum StreamBlockStatus : int32_t {
+  UNKNOWN = 0,
+  RUNNING = 1,
+  TIMED_OUT = 2,
+  ABORTED = 3,
+};
+
+/*
+StreamBlock implements a baton that will block the active CUDA stream
+until aborted by the main process.
+*/
+class TORCH_API StreamBlock {
+ public:
+  virtual ~StreamBlock() = default;
+  virtual void abort() = 0;
+  virtual StreamBlockStatus status() = 0;
+};
+
+std::unique_ptr<StreamBlock> block_stream(std::chrono::milliseconds timeout);
+
+// Declare a registry so we can call the CUDA StreamBlock API from CPU-only
+// code (i.e. ProcessGroup/Work objects in libtorch_cpu).
+// The implementation is defined in StreamBlock.cu.
+TORCH_DECLARE_REGISTRY(
+    StreamBlockRegistry,
+    StreamBlock,
+    std::chrono::milliseconds);
+
+} // namespace c10d::cuda
@@ -3527,6 +3527,21 @@ such as `dist.all_reduce(tensor, async_op=True)`.
           However, if timeout is set, it will block the CPU thread until the NCCL work is completed
           or timed out. If timeout, exception will be thrown.
       )")
+      .def(
+          "block_current_stream",
+          &::c10d::Work::blockCurrentStream,
+          py::call_guard<py::gil_scoped_release>(),
+          R"(
+          Blocks the currently active GPU stream on the operation to
+          complete. For GPU based collectives this is equivalent to
+          synchronize. For CPU initiated collectives such as with Gloo this
+          will block the CUDA stream until the operation is complete.
+
+          This returns immediately in all cases.
+
+          To check whether an operation was successful you should check the
+          Work object result asynchronously.
+          )")
       .def(
           "get_future_result",
           [](::c10d::Work& work)
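As a sketch (not from this PR) of the asynchronous success check that the docstring above refers to, one way to observe the outcome without blocking the CPU thread is through the Work's future; the callback below is illustrative only.

```python
import torch
import torch.distributed as dist

# Assumes a process group is already initialized and this rank has a CUDA device.
t = torch.rand(10, device="cuda")
work = dist.all_reduce(t, async_op=True)
work.block_current_stream()  # blocks only the CUDA stream, never the CPU thread

def _on_done(fut):
    # The future is already complete inside the callback; wait() re-raises
    # the collective's exception if it failed.
    try:
        fut.wait()
        print("allreduce succeeded")
    except RuntimeError as exc:
        print(f"allreduce failed: {exc}")

# Check success asynchronously instead of calling work.wait() on the CPU.
work.get_future().add_done_callback(_on_done)
```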