diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py
index 6ccb81b116ca..65faf2075daa 100644
--- a/test/distributed/test_c10d_pypg.py
+++ b/test/distributed/test_c10d_pypg.py
@@ -181,6 +181,19 @@ class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiThreadedTestCase):
         return True
 
 
+class BlockWork(dist._Work):
+    """
+    Dummy work that is used to test blocking the current stream.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.future_ = torch.futures.Future()
+
+    def get_future(self):
+        return self.future_
+
+
 class TestPyProcessGroup(TestCase):
     def test_attr_overrides(self):
         pg = DummyAttrProcessGroup(0, 1)
@@ -202,34 +215,61 @@ class TestPyProcessGroup(TestCase):
 
     @unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
     def test_block_current_stream(self) -> None:
-        class BlockWork(dist._Work):
-            def __init__(self):
-                super().__init__()
-                self.future_ = torch.futures.Future()
+        torch.cuda.synchronize()
 
-            def get_future(self):
-                return self.future_
+        stream = torch.cuda.Stream()
+        with stream:
+            # nothing in queue so instantly resolves
+            event1 = torch.cuda.Event()
+            event1.record()
+            time.sleep(0.1)
+            self.assertTrue(event1.query())
 
-        # nothing in queue so instantly resolves
-        event1 = torch.cuda.Event()
-        event1.record()
-        time.sleep(0.1)
-        self.assertTrue(event1.query())
+            work = BlockWork()
+            work.block_current_stream()
 
-        work = BlockWork()
-        work.block_current_stream()
+            # stream is blocked so doesn't resolve
+            event = torch.cuda.Event()
+            event.record()
+            time.sleep(0.1)
+            self.assertFalse(event.query())
 
-        # stream is blocked so doesn't resolve
-        event = torch.cuda.Event()
-        event.record()
-        time.sleep(0.1)
-        self.assertFalse(event.query())
+            # resolve the work
+            work.get_future().set_result(None)
 
-        # resolve the work
-        work.get_future().set_result(None)
+            stream.synchronize()
+            self.assertTrue(event.query())
 
-        torch.cuda.current_stream().synchronize()
-        self.assertTrue(event.query())
+    @unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
+    def test_block_current_stream_use_after_free(self) -> None:
+        """
+        This tests that the CPU control tensor is not freed before the CUDA kernel executes.
+        """
+        torch.cuda.synchronize()
+        stream = torch.cuda.Stream()
+        with stream:
+            a = BlockWork()
+            a.block_current_stream()
+
+            b = BlockWork()
+            b.block_current_stream()
+
+            # unblock b first though a is still blocking
+            b.get_future().set_result(None)
+            # delete b
+            del b
+
+            # a is still blocking so this doesn't resolve
+            event = torch.cuda.Event()
+            event.record()
+            time.sleep(0.1)
+            self.assertFalse(event.query())
+
+            # unblock a
+            a.get_future().set_result(None)
+
+            stream.synchronize()
+            self.assertTrue(event.query())
 
 
 if __name__ == "__main__":
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu
index 58533ece6af8..db4a118a25e5 100644
--- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu
@@ -1,5 +1,7 @@
+#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -8,7 +10,7 @@
 #include
 #include
 #else
-#include
+#include
 #endif
 
 namespace c10d::cuda::detail {
@@ -21,19 +23,49 @@ __device__ void nanosleep(int64_t ns) {
 #endif
 }
 
+__device__ int32_t load_cpu_int32(int32_t* ptr) {
+#if defined(USE_ROCM)
+  // WARNING: this may not be safe
+  return atomicAdd_system(ptr, 0);
+#else
+  int32_t current_value = 0;
+
+  // Bypass L1 cache to see updates at L2 and above.
+  // This could use .cv to bypass L2 cache but that's significantly more
+  // expensive and the CPU write will clear the L2 cache.
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators
+  asm volatile("ld.cg.s32 %0, [%1];"
+               : "=r"(current_value) // Output operand
+               : "l"(ptr) // Input operand
+  );
+  return current_value;
+#endif
+}
+
+__device__ void store_cpu_int32(int32_t* ptr, int32_t val) {
+#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
+  // WARNING: this value may be cached without .release
+  *ptr = val;
+#else
+  // Releases memory so it can be seen by other threads on the system.
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#release-acquire-patterns
+  asm volatile("st.release.sys.s32 [%0], %1;" ::"l"(ptr), "r"(val));
+#endif
+}
+
 __global__
 // set launch bounds to limit to 1 thread per block, 1 block per MP
 __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
-  value[1] = StreamBlockStatus::RUNNING;
+  store_cpu_int32(&value[1], StreamBlockStatus::RUNNING);
   size_t start = c10d::symmetric_memory::global_timer_ns();
   size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
   while (true) {
     // Atomically read the value
-    int current_value = atomicAdd(&value[0], 0);
+    int32_t current_value = load_cpu_int32(value);
     // Check if the value is equal to the expected value
     if (current_value == 1) {
-      value[1] = StreamBlockStatus::ABORTED;
+      store_cpu_int32(&value[1], StreamBlockStatus::ABORTED);
       return;
     }
 
@@ -41,7 +73,7 @@ __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
     // Check if timeout has been reached
     size_t now = c10d::symmetric_memory::global_timer_ns();
     if ((now - start) > timeout_ns) {
-      value[1] = StreamBlockStatus::TIMED_OUT;
+      store_cpu_int32(&value[1], StreamBlockStatus::TIMED_OUT);
       return;
     }
   }
@@ -55,13 +87,21 @@ StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
     : comm_{
          // We need to pin the memory since we access the CPU memory directly form
          // the GPU.
-         at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
+         at::zeros({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
      },
      timeout_{timeout} {
+  auto stream = at::cuda::getCurrentCUDAStream();
+  auto* ptr = comm_.mutable_data_ptr<int32_t>();
+  auto* ctx = comm_.storage().data_ptr().get_context();
+
   // grid size 1, block size 1, 0 bytes of shared memory
-  kernel_barrier<<<1, 1, 0>>>(
-      comm_.mutable_data_ptr<int32_t>(), timeout_.count());
+  kernel_barrier<<<1, 1, 0, stream>>>(ptr, timeout_.count());
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  // This object may be deallocated before the CUDA kernel completes. We need to
+  // register the CPU tensor so it's only freed after the kernel completes
+  // execution.
+  at::getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap());
 }
 
 C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)
diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
index f94f272d7eef..9ca52b4c5e88 100644
--- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
+++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh
@@ -13,6 +13,7 @@ class StreamBlock : public ::c10d::cuda::StreamBlock {
   StreamBlock(std::chrono::milliseconds timeout);
 
   void abort() override {
+    std::atomic_thread_fence(std::memory_order_seq_cst);
     comm_[0] = 1;
   }
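For reviewers who want to poke at this outside the test harness, below is a minimal standalone sketch of the pattern the updated tests exercise. It is not part of the patch: the main() wrapper is hypothetical, BlockWork is copied from the test above, and it assumes a CUDA-enabled build where dist._Work.block_current_stream() is available. Per the kernel change above, resolving the work's future is what allows the pinned control word to be written so the spin kernel exits.

    # Standalone sketch (not part of this diff) of the block_current_stream pattern.
    import torch
    import torch.distributed as dist


    class BlockWork(dist._Work):
        """Dummy work whose completion is driven by an externally resolved future."""

        def __init__(self):
            super().__init__()
            self.future_ = torch.futures.Future()

        def get_future(self):
            return self.future_


    def main() -> None:
        if not torch.cuda.is_available():
            return
        torch.cuda.synchronize()

        stream = torch.cuda.Stream()
        with stream:
            work = BlockWork()
            # Enqueues the spin kernel on the current stream; anything queued
            # afterwards waits until the work resolves (or the kernel times out).
            work.block_current_stream()

            event = torch.cuda.Event()
            event.record()
            assert not event.query()  # still stuck behind the spin kernel

            # Resolving the future releases the block; on the C++ side the pinned
            # control word polled by kernel_barrier() gets set and the kernel exits.
            work.get_future().set_result(None)

            stream.synchronize()
            assert event.query()


    if __name__ == "__main__":
        main()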