mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] block_current_stream: correctness fixes (#158757)
This fixes a number of issues that were present in https://github.com/pytorch/pytorch/pull/156883, as pointed out by @ngimel. Test plan: expanded the tests to cover use-after-free behavior and a non-default stream. ``` pytest test/distributed/test_c10d_pypg.py -v -k block_current_stream ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/158757 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd0adc9386
commit
4366610f5a
@ -181,6 +181,19 @@ class TestDDPWithWorkWrapper(AbstractDDPSingleRank, MultiThreadedTestCase):
|
||||
return True
|
||||
|
||||
|
||||
class BlockWork(dist._Work):
    """
    Test-only work object whose completion is driven by hand.

    It owns a single future that stays pending until a test resolves it, so
    tests decide exactly when the work counts as finished while exercising
    stream blocking.
    """

    def __init__(self):
        super(BlockWork, self).__init__()
        # Created once here and handed out unchanged by get_future().
        pending = torch.futures.Future()
        self.future_ = pending

    def get_future(self):
        """Return the future created for this work in ``__init__``."""
        return self.future_
|
||||
|
||||
|
||||
class TestPyProcessGroup(TestCase):
|
||||
def test_attr_overrides(self):
|
||||
pg = DummyAttrProcessGroup(0, 1)
|
||||
@ -202,34 +215,61 @@ class TestPyProcessGroup(TestCase):
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
|
||||
def test_block_current_stream(self) -> None:
|
||||
class BlockWork(dist._Work):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.future_ = torch.futures.Future()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def get_future(self):
|
||||
return self.future_
|
||||
stream = torch.cuda.Stream()
|
||||
with stream:
|
||||
# nothing in queue so instantly resolves
|
||||
event1 = torch.cuda.Event()
|
||||
event1.record()
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(event1.query())
|
||||
|
||||
# nothing in queue so instantly resolves
|
||||
event1 = torch.cuda.Event()
|
||||
event1.record()
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(event1.query())
|
||||
work = BlockWork()
|
||||
work.block_current_stream()
|
||||
|
||||
work = BlockWork()
|
||||
work.block_current_stream()
|
||||
# stream is blocked so doesn't resolve
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(event.query())
|
||||
|
||||
# stream is blocked so doesn't resolve
|
||||
event = torch.cuda.Event()
|
||||
event.record()
|
||||
time.sleep(0.1)
|
||||
self.assertFalse(event.query())
|
||||
# resolve the work
|
||||
work.get_future().set_result(None)
|
||||
|
||||
# resolve the work
|
||||
work.get_future().set_result(None)
|
||||
stream.synchronize()
|
||||
self.assertTrue(event.query())
|
||||
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.assertTrue(event.query())
|
||||
@unittest.skipIf(not TEST_CUDA, "no cuda/xpu")
def test_block_current_stream_use_after_free(self) -> None:
    """
    This tests that the CPU control tensor is not freed before the CUDA kernel executes.

    Two works block the same non-default stream. The second one is resolved
    and deleted while the first is still outstanding; the stream must remain
    blocked until the first one is resolved as well.
    """
    torch.cuda.synchronize()

    side_stream = torch.cuda.Stream()
    with side_stream:
        first = BlockWork()
        first.block_current_stream()

        second = BlockWork()
        second.block_current_stream()

        # Resolve the second work while the first one is still blocking ...
        second.get_future().set_result(None)
        # ... and drop the reference to it.
        del second

        # The first work has not been resolved, so an event recorded now
        # must still be pending after a short wait.
        marker = torch.cuda.Event()
        marker.record()
        time.sleep(0.1)
        self.assertFalse(marker.query())

        # Resolve the first work; the stream can now drain.
        first.get_future().set_result(None)

        side_stream.synchronize()
        self.assertTrue(marker.query())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <ATen/cuda/CachingHostAllocator.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h>
|
||||
#include <torch/csrc/distributed/c10d/cuda/StreamBlock.cuh>
|
||||
@ -8,7 +10,7 @@
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/zeros.h>
|
||||
#endif
|
||||
|
||||
namespace c10d::cuda::detail {
|
||||
@ -21,19 +23,49 @@ __device__ void nanosleep(int64_t ns) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Reads one int32 from `ptr` in a way intended to observe writes made outside
// the GPU, and returns the value read.
//
// CUDA: issues an `ld.cg` load, which bypasses the L1 cache so updates at L2
// and above are seen. `.cv` would also bypass L2, but it is significantly more
// expensive and the CPU write will clear the L2 cache.
// https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators
//
// ROCm: reads through a system-scope atomic add of zero.
__device__ int32_t load_cpu_int32(int32_t* ptr) {
#if !defined(USE_ROCM)
  int32_t observed = 0;
  asm volatile("ld.cg.s32 %0, [%1];"
               : "=r"(observed) // value loaded from memory
               : "l"(ptr) // address to load from
  );
  return observed;
#else
  // WARNING: this may not be safe
  return atomicAdd_system(ptr, 0);
#endif
}
|
||||
|
||||
// Writes `val` to `*ptr` so that it can be seen by other threads on the
// system.
//
// CUDA with __CUDA_ARCH__ >= 700: a system-scope release store. The asm
// statement carries a "memory" clobber so the compiler cannot move earlier
// memory writes past the release store.
// https://docs.nvidia.com/cuda/parallel-thread-execution/#release-acquire-patterns
//
// ROCm and CUDA with __CUDA_ARCH__ < 700: `st.release` is not used on these
// targets, so a system-wide fence is followed by a volatile store.
__device__ void store_cpu_int32(int32_t* ptr, int32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
  // Order this thread's earlier writes ahead of the store for all observers on
  // the system, then store through a volatile pointer so the compiler must
  // emit the write here rather than defer it or keep the value in a register.
  __threadfence_system();
  *reinterpret_cast<volatile int32_t*>(ptr) = val;
#else
  // Releases memory so it can be seen by other threads on the system.
  asm volatile("st.release.sys.s32 [%0], %1;" ::"l"(ptr), "r"(val)
               : "memory");
#endif
}
|
||||
|
||||
__global__
|
||||
// set launch bounds to limit to 1 thread per block, 1 block per MP
|
||||
__launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
|
||||
value[1] = StreamBlockStatus::RUNNING;
|
||||
store_cpu_int32(&value[1], StreamBlockStatus::RUNNING);
|
||||
|
||||
size_t start = c10d::symmetric_memory::global_timer_ns();
|
||||
size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds
|
||||
while (true) {
|
||||
// Atomically read the value
|
||||
int current_value = atomicAdd(&value[0], 0);
|
||||
int32_t current_value = load_cpu_int32(value);
|
||||
// Check if the value is equal to the expected value
|
||||
if (current_value == 1) {
|
||||
value[1] = StreamBlockStatus::ABORTED;
|
||||
store_cpu_int32(&value[1], StreamBlockStatus::ABORTED);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -41,7 +73,7 @@ __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) {
|
||||
// Check if timeout has been reached
|
||||
size_t now = c10d::symmetric_memory::global_timer_ns();
|
||||
if ((now - start) > timeout_ns) {
|
||||
value[1] = StreamBlockStatus::TIMED_OUT;
|
||||
store_cpu_int32(&value[1], StreamBlockStatus::TIMED_OUT);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -55,13 +87,21 @@ StreamBlock::StreamBlock(std::chrono::milliseconds timeout)
|
||||
: comm_{
|
||||
// We need to pin the memory since we access the CPU memory directly from
|
||||
// the GPU.
|
||||
at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
|
||||
at::zeros({2}, at::TensorOptions().dtype(at::kInt)).pin_memory()
|
||||
},
|
||||
timeout_{timeout} {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
auto* ptr = comm_.mutable_data_ptr<int32_t>();
|
||||
auto* ctx = comm_.storage().data_ptr().get_context();
|
||||
|
||||
// grid size 1, block size 1, 0 bytes of shared memory
|
||||
kernel_barrier<<<1, 1, 0>>>(
|
||||
comm_.mutable_data_ptr<int32_t>(), timeout_.count());
|
||||
kernel_barrier<<<1, 1, 0, stream>>>(ptr, timeout_.count());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// This object may be deallocated before the CUDA kernel completes. We need to
|
||||
// register the CPU tensor so it's only freed after the kernel completes
|
||||
// execution.
|
||||
at::getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap());
|
||||
}
|
||||
|
||||
C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock)
|
||||
|
@ -13,6 +13,7 @@ class StreamBlock : public ::c10d::cuda::StreamBlock {
|
||||
StreamBlock(std::chrono::milliseconds timeout);
|
||||
|
||||
// Requests that the wait be abandoned by writing 1 into element 0 of the
// control tensor `comm_`.
void abort() override {
|
||||
// Full sequentially-consistent fence, issued before the flag write below.
std::atomic_thread_fence(std::memory_order_seq_cst);
|
||||
// Element 0 is the abort flag; 1 means "abort requested".
comm_[0] = 1;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user