[c10d] Fix CudaEventCache for dangling references (#144496)

As reported in https://github.com/pytorch/pytorch/issues/143470, `CudaEventCache` can end up with a dangling reference. This PR fixes it:
1. We add a unit test that reproduces the issue described in the report.
2. Instead of converting individual member variables to shared pointers as suggested in the issue, we make the cache itself a shared pointer (see the sketch below). If the thread that created the cache dies before all events have been recycled, the cache stays alive until the last CUDAEvent is deleted. (Thanks to @kwen2501 for the suggestion.)
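
To make the ownership change concrete, here is a minimal standalone C++ sketch of the pattern, not the actual c10d implementation: `EventCache` and `FakeEvent` are illustrative stand-ins for `CUDAEventCache` and `at::cuda::CUDAEvent`. The cache inherits `std::enable_shared_from_this`, the per-thread map stores `shared_ptr`s, and every handed-out event's deleter captures `shared_from_this()`, so the cache outlives the thread that created it until the last event is returned.

```cpp
// Minimal sketch (illustrative names, not the real c10d classes).
#include <deque>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <thread>

struct FakeEvent {}; // stand-in for at::cuda::CUDAEvent

class EventCache : public std::enable_shared_from_this<EventCache> {
 public:
  ~EventCache() {
    for (FakeEvent* e : pool_) {
      delete e;
    }
  }

  std::shared_ptr<FakeEvent> create() {
    FakeEvent* ev = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!pool_.empty()) {
        ev = pool_.front();
        pool_.pop_front();
      }
    }
    if (ev == nullptr) {
      ev = new FakeEvent();
    }
    // The deleter owns a shared_ptr to the cache, so the cache cannot be
    // destroyed while any handed-out event is still alive.
    auto deleter = [cache = shared_from_this()](FakeEvent* e) {
      std::lock_guard<std::mutex> lock(cache->mutex_);
      cache->pool_.push_back(e); // recycle instead of deleting
    };
    return std::shared_ptr<FakeEvent>(ev, std::move(deleter));
  }

  static std::shared_ptr<EventCache> get(int device) {
    // Per-thread, per-device caches. The map stores shared_ptr, so when the
    // thread_local map is destroyed with its thread, outstanding deleters
    // still keep the cache alive.
    static thread_local std::map<int, std::shared_ptr<EventCache>> caches;
    auto it = caches.find(device);
    if (it == caches.end()) {
      it = caches.emplace(device, std::make_shared<EventCache>()).first;
    }
    return it->second;
  }

 private:
  std::mutex mutex_;
  std::deque<FakeEvent*> pool_;
};

int main() {
  std::shared_ptr<FakeEvent> ev;
  std::thread side([&] { ev = EventCache::get(0)->create(); });
  side.join();
  // The side thread and its thread_local map are gone, but the deleter's
  // shared_ptr still owns the cache, so releasing the event here is safe.
  ev.reset();
  std::cout << "event recycled without touching a destroyed cache\n";
  return 0;
}
```

In the actual patch the same idea shows up as `CUDAEventCache::get()` returning a `std::shared_ptr` and the deleter in `create()` capturing `shared_from_this()` instead of `this`.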

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144496
Approved by: https://github.com/kwen2501
Author: fduwjj
Date: 2025-01-14 10:13:28 -08:00
Committed by: PyTorch MergeBot
Parent: 9cd6f46130
Commit: ae7df51232

4 changed files with 58 additions and 27 deletions


@@ -767,8 +767,8 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
   }
   // Test that the CUDAEventCache can be used to create CUDA events and reuse.
-  auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(true);
-  auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(false);
+  auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
+  auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
   auto event1_ptr = event1.get();
   auto event2_ptr = event2.get();
@@ -777,14 +777,14 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
   event2 = nullptr;
   // Test that the CUDAEventCache is indeed reused.
-  auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2).create(true);
-  auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2).create(false);
+  auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true);
+  auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false);
   // The cache has been used up, new events should be created.
-  auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(true);
-  auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(false);
+  auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
+  auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
   // The cache has been used up, new events should be created.
-  auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(true);
-  auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1).create(false);
+  auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
+  auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
   EXPECT_NE(event1_ptr, event3.get());
   EXPECT_NE(event2_ptr, event4.get());
   EXPECT_EQ(event1_ptr, event5.get());


@@ -438,6 +438,35 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         with self.assertRaises(ValueError):
             dist.all_reduce(t1)

+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_cuda_event_cache_mthd_race(self):
+        # This unit test is to test the case when the collective is launched in
+        # a side thread and the thread dies before the cache has been fully recycled.
+        # More details can be found in this issue: https://github.com/pytorch/pytorch/issues/143470.
+        import threading
+
+        # initiate collectives here
+        def init_collective_task(t):
+            dist.all_reduce(t)
+            dist.all_reduce(t)
+            dist.all_reduce(t)
+
+        os.environ["TORCH_NCCL_CUDA_EVENT_CACHE"] = "1"
+        store = c10d.FileStore(self.file_name, self.world_size)
+        self._create_process_group_nccl(store, self.opts())
+        device = self.rank_to_GPU[self.rank][0]
+        t = torch.rand(10, 10, device=device)
+        # First allreduce to initialize state.
+        dist.all_reduce(t)
+        dist.all_reduce(t)
+        dist.all_reduce(t)
+
+        side_thread = threading.Thread(target=init_collective_task, args=(t,))
+        side_thread.start()
+        side_thread.join()
+        torch.cuda.synchronize()
+
     CUDA_12_AND_ABOVE = torch.cuda.is_available() and (
         torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12
     )


@@ -482,10 +482,10 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
   if (cudaEventCacheEnabled) {
     ncclStartEvent_ = enableTiming
         ? ProcessGroupNCCL::CUDAEventCache::get(device.index())
-              .create(enableTiming)
+              ->create(enableTiming)
         : nullptr;
     ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index())
-                        .create(enableTiming);
+                        ->create(enableTiming);
   } else {
     ncclStartEvent_ = enableTiming
         ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
@@ -816,12 +816,17 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default;
 // This is to avoid the potential deadlock caused by CudaEventDestroy.
 std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
     bool timing) {
-  // register the deleter as a callback when the WorkNCCL object is destroyed.
-  auto deleter = [this, timing](at::cuda::CUDAEvent* event) {
-    std::lock_guard<std::mutex> lock(this->cacheMutex_);
+  // Register the deleter as a callback when the WorkNCCL object is destroyed.
+  // Each deleter keeps a ref count to the cache object, so that even when
+  // the thread that creates the cache is gone, the cache object won't be
+  // destroyed until all the events in the cache are destroyed (ref number drops
+  // to zero).
+  auto deleter = [cache = shared_from_this(),
+                  timing](at::cuda::CUDAEvent* event) {
+    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
     // We put the event back to the cache deque once the WorkNCCL object is
     // destroyed.
-    this->eventsArray_[timing ? 1 : 0].push_back(event);
+    cache->eventsArray_[timing ? 1 : 0].push_back(event);
   };
   at::cuda::CUDAEvent* event = nullptr;
   {
@@ -840,27 +845,22 @@ std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
   return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
 }

-ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get(
-    at::DeviceIndex device) {
+std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> ProcessGroupNCCL::
+    CUDAEventCache::get(at::DeviceIndex device) {
   // A per-thread singleton of device-to-CUDAEventCache map.
   // Map is needed because events cannot be reused across devices.
   // Per-thread ownership is needed to support multi-threaded case (instead of
   // multi-process case).
   static thread_local std::
-      map<at::DeviceIndex, ProcessGroupNCCL::CUDAEventCache>
+      map<at::DeviceIndex, std::shared_ptr<ProcessGroupNCCL::CUDAEventCache>>
          cacheDeviceMap;
   // Check if device has already been in the map, if not, add a new entry
   auto it = cacheDeviceMap.find(device);
   if (it == cacheDeviceMap.end()) {
-    // Use in-place contruction, which avoids move or copy of the cache
-    // (the mutex of the cache is not movable/copiable)
-    it = cacheDeviceMap.emplace_hint(
-        it,
-        std::piecewise_construct,
-        std::forward_as_tuple(device),
-        std::forward_as_tuple());
+    cacheDeviceMap.emplace(
+        device, std::make_shared<ProcessGroupNCCL::CUDAEventCache>());
   }
-  return it->second;
+  return cacheDeviceMap[device];
 }

 static std::atomic<size_t> process_group_id = 0;


@@ -455,11 +455,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     friend class ProcessGroupNCCL;
   };

-  class CUDAEventCache {
+  class CUDAEventCache
+      : public std::enable_shared_from_this<ProcessGroupNCCL::CUDAEventCache> {
   public:
    CUDAEventCache();
    std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
-   static CUDAEventCache& get(at::DeviceIndex device);
+   static std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> get(
+       at::DeviceIndex device);

   private:
    std::mutex cacheMutex_;