Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[cca] [c10d] Refactor CUDAEventCache into separate files (#158616)
Summary: Refactored CUDAEventCache out of ProcessGroupNCCL.hpp/.cpp into a dedicated header and implementation file for better code organization and maintainability.

Split out CUDAEventCache into:
- New header file: CUDAEventCache.hpp
- New implementation file: CUDAEventCache.cpp
- Updated build_variables.bzl to include the new file

This change improves code maintainability and readability and follows better code-organization practices.

---
> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Trace)

Test Plan: Verified the build with:
```
buck build //caffe2/test/distributed:c10d
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158616
Approved by: https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: 90b082e207
Commit: ab557421a4
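After this refactor, callers reference the cache as a top-level class in the c10d namespace instead of as a nested class of ProcessGroupNCCL. A minimal sketch of the call-site change implied by the diff below (the wrapper function is illustrative only; `get()` and `create()` are the API shown in CUDAEventCache.hpp):

```cpp
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

// Before: nested inside ProcessGroupNCCL.
//   auto ev = c10d::ProcessGroupNCCL::CUDAEventCache::get(dev)->create(/*timing=*/true);
// After: free-standing class in the c10d namespace.
std::shared_ptr<at::cuda::CUDAEvent> makeStartEvent(at::DeviceIndex dev) {
  // get() returns the calling thread's per-device cache; create() hands out a
  // pooled event whose shared_ptr deleter recycles it back into the cache.
  return c10d::CUDAEventCache::get(dev)->create(/*timing=*/true);
}
```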
build_variables.bzl
@@ -738,6 +738,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
+    "torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp",
     "torch/csrc/distributed/c10d/cuda/utils.cpp",
     "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
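Listing the new translation unit in `libtorch_cuda_distributed_extra_sources` is what gets CUDAEventCache.cpp compiled into the CUDA-distributed part of the build; without this entry the definitions moved out of ProcessGroupNCCL.cpp would likely be missing at link time.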
ProcessGroupNCCLTest.cpp
@@ -767,8 +767,8 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
 }
 
   // Test that the CUDAEventCache can be used to create CUDA events and reuse.
-  auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event1 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event2 = c10d::CUDAEventCache::get(1)->create(false);
 
   auto event1_ptr = event1.get();
   auto event2_ptr = event2.get();
@@ -777,14 +777,14 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
   event2 = nullptr;
 
   // Test that the CUDAEventCache is indeed reused.
-  auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true);
-  auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false);
+  auto event3 = c10d::CUDAEventCache::get(2)->create(true);
+  auto event4 = c10d::CUDAEventCache::get(2)->create(false);
   // The cache has been used up, new events should be created.
-  auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event5 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event6 = c10d::CUDAEventCache::get(1)->create(false);
   // The cache has been used up, new events should be created.
-  auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event7 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event8 = c10d::CUDAEventCache::get(1)->create(false);
   EXPECT_NE(event1_ptr, event3.get());
   EXPECT_NE(event2_ptr, event4.get());
   EXPECT_EQ(event1_ptr, event5.get());
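The assertions exercise the recycling behavior rather than the rename itself: event3/event4 come from device 2's cache, so they cannot alias event1/event2, while event5/event6 are requested from device 1's cache after event1/event2 have been released back to it, so the same underlying at::cuda::CUDAEvent objects are expected to be handed out again.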
torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -519,11 +519,9 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
   // DEFAULT_FLAGS = cudaEventDisableTiming.
   if (cudaEventCacheEnabled) {
     ncclStartEvent_ = enableTiming
-        ? ProcessGroupNCCL::CUDAEventCache::get(device.index())
-              ->create(enableTiming)
+        ? CUDAEventCache::get(device.index())->create(enableTiming)
         : nullptr;
-    ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index())
-                        ->create(enableTiming);
+    ncclEndEvent_ = CUDAEventCache::get(device.index())->create(enableTiming);
   } else {
     ncclStartEvent_ = enableTiming
         ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
@@ -860,61 +858,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
   }
 }
 
-ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default;
-
-// CUDA event is used to record the start/end of one Work.
-// Instead of let the CUDA event gets destroyed, we now reuse it after the Work
-// has been erased from workMetaList_.
-// This is to avoid the potential deadlock caused by CudaEventDestroy.
-std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
-    bool timing) {
-  // Register the deleter as a callback when the WorkNCCL object is destroyed.
-  // Each deleter keeps a ref count to the cache object, so that even when
-  // the thread that creates the cache is gone, the cache object won't be
-  // destroyed until all the events in the cache are destroyed (ref number drops
-  // to zero).
-  auto deleter = [cache = shared_from_this(),
-                  timing](at::cuda::CUDAEvent* event) {
-    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
-    // We put the event back to the cache deque once the WorkNCCL object is
-    // destroyed.
-    cache->eventsArray_[timing ? 1 : 0].push_back(event);
-  };
-  at::cuda::CUDAEvent* event = nullptr;
-  {
-    std::lock_guard<std::mutex> lock(cacheMutex_);
-    auto& events = eventsArray_[timing ? 1 : 0];
-    // If we still have events in the cache, we reuse it. Otherwise, we create a
-    // new one.
-    if (!events.empty()) {
-      event = events.front();
-      events.pop_front();
-    } else {
-      event = new at::cuda::CUDAEvent(
-          timing ? cudaEventDefault : cudaEventDisableTiming);
-    }
-  }
-  return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
-}
-
-std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> ProcessGroupNCCL::
-    CUDAEventCache::get(at::DeviceIndex device) {
-  // A per-thread singleton of device-to-CUDAEventCache map.
-  // Map is needed because events cannot be reused across devices.
-  // Per-thread ownership is needed to support multi-threaded case (instead of
-  // multi-process case).
-  static thread_local std::
-      map<at::DeviceIndex, std::shared_ptr<ProcessGroupNCCL::CUDAEventCache>>
-          cacheDeviceMap;
-  // Check if device has already been in the map, if not, add a new entry
-  auto it = cacheDeviceMap.find(device);
-  if (it == cacheDeviceMap.end()) {
-    cacheDeviceMap.emplace(
-        device, std::make_shared<ProcessGroupNCCL::CUDAEventCache>());
-  }
-  return cacheDeviceMap[device];
-}
-
 static std::atomic<size_t> process_group_id = 0;
 
 constexpr const char* MULTI_DEVICE_ERROR_MSG =
torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -23,6 +23,7 @@
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>
 #include <torch/csrc/distributed/c10d/logger.hpp>
 #include <torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp>
@@ -503,23 +504,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     friend class ProcessGroupNCCL;
   };
 
-  class CUDAEventCache
-      : public std::enable_shared_from_this<ProcessGroupNCCL::CUDAEventCache> {
-   public:
-    CUDAEventCache();
-    std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
-    static std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> get(
-        at::DeviceIndex device);
-
-   private:
-    std::mutex cacheMutex_;
-    // NOTE: We intentionally store raw pointers so that
-    // we do not attempt to destroy the event objects on process exit,
-    // because cuda may be gone.
-    std::array<std::deque<at::cuda::CUDAEvent*>, 2>
-        eventsArray_; // 0 for timing=false, 1 for timing=true
-  };
-
   struct Options : Backend::Options {
     // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
     // operations. This is only used when blockingWait_ is enabled.
torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp (new file, 58 lines)
@@ -0,0 +1,58 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>
#include <map>

namespace c10d {

CUDAEventCache::CUDAEventCache() = default;

// CUDA event is used to record the start/end of one Work.
// Instead of let the CUDA event gets destroyed, we now reuse it after the Work
// has been erased from workMetaList_.
// This is to avoid the potential deadlock caused by CudaEventDestroy.
std::shared_ptr<at::cuda::CUDAEvent> CUDAEventCache::create(bool timing) {
  // Register the deleter as a callback when the WorkNCCL object is destroyed.
  // Each deleter keeps a ref count to the cache object, so that even when
  // the thread that creates the cache is gone, the cache object won't be
  // destroyed until all the events in the cache are destroyed (ref number drops
  // to zero).
  auto deleter = [cache = shared_from_this(),
                  timing](at::cuda::CUDAEvent* event) {
    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
    // We put the event back to the cache deque once the WorkNCCL object is
    // destroyed.
    cache->eventsArray_[timing ? 1 : 0].push_back(event);
  };
  at::cuda::CUDAEvent* event = nullptr;
  {
    std::lock_guard<std::mutex> lock(cacheMutex_);
    auto& events = eventsArray_[timing ? 1 : 0];
    // If we still have events in the cache, we reuse it. Otherwise, we create a
    // new one.
    if (!events.empty()) {
      event = events.front();
      events.pop_front();
    } else {
      event = new at::cuda::CUDAEvent(
          timing ? cudaEventDefault : cudaEventDisableTiming);
    }
  }
  return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
}

std::shared_ptr<CUDAEventCache> CUDAEventCache::get(at::DeviceIndex device) {
  // A per-thread singleton of device-to-CUDAEventCache map.
  // Map is needed because events cannot be reused across devices.
  // Per-thread ownership is needed to support multi-threaded case (instead of
  // multi-process case).
  static thread_local std::map<at::DeviceIndex, std::shared_ptr<CUDAEventCache>>
      cacheDeviceMap;
  // Check if device has already been in the map, if not, add a new entry
  auto it = cacheDeviceMap.find(device);
  if (it == cacheDeviceMap.end()) {
    cacheDeviceMap.emplace(device, std::make_shared<CUDAEventCache>());
  }
  return cacheDeviceMap[device];
}

} // namespace c10d
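CUDAEventCache::create above recycles events by handing out std::shared_ptrs whose custom deleter pushes the raw pointer back into the per-cache deque instead of destroying it. A minimal, self-contained sketch of that same pattern, using a hypothetical IntPool with a plain int payload rather than CUDA events:

```cpp
#include <deque>
#include <memory>
#include <mutex>

// Hypothetical stand-in for CUDAEventCache: objects are handed out as
// shared_ptrs whose custom deleter returns the raw pointer to the pool.
class IntPool : public std::enable_shared_from_this<IntPool> {
 public:
  std::shared_ptr<int> create() {
    // The deleter holds a shared_ptr to the pool, so the pool stays alive
    // until every object it has handed out has been returned.
    auto deleter = [pool = shared_from_this()](int* p) {
      std::lock_guard<std::mutex> lock(pool->mutex_);
      pool->free_.push_back(p);  // recycle instead of delete
    };
    int* p = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!free_.empty()) {
        p = free_.front();
        free_.pop_front();
      } else {
        p = new int(0);
      }
    }
    return std::shared_ptr<int>(p, std::move(deleter));
  }

 private:
  std::mutex mutex_;
  std::deque<int*> free_;  // raw pointers, intentionally never deleted
};

int main() {
  auto pool = std::make_shared<IntPool>();
  auto a = pool->create();
  int* first = a.get();
  a.reset();                // deleter runs: the int goes back into the pool
  auto b = pool->create();  // the same allocation is reused
  return first == b.get() ? 0 : 1;
}
```

Like the cache above, the pool never deletes the pooled pointers itself, which mirrors the NOTE in the header about not destroying events at process exit.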
torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp (new file, 29 lines)
@@ -0,0 +1,29 @@
#pragma once

#include <array>
#include <deque>
#include <memory>
#include <mutex>

#include <ATen/cuda/CUDAEvent.h>
#include <c10/macros/Export.h>

namespace c10d {

class TORCH_API CUDAEventCache
    : public std::enable_shared_from_this<CUDAEventCache> {
 public:
  CUDAEventCache();
  std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
  static std::shared_ptr<CUDAEventCache> get(at::DeviceIndex device);

 private:
  std::mutex cacheMutex_;
  // NOTE: We intentionally store raw pointers so that
  // we do not attempt to destroy the event objects on process exit,
  // because cuda may be gone.
  std::array<std::deque<at::cuda::CUDAEvent*>, 2>
      eventsArray_; // 0 for timing=false, 1 for timing=true
};

} // namespace c10d
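For reference, a minimal sketch of how a consumer might use the new header (the surrounding function and stream argument are illustrative only; `get()` and `create()` are the API declared above):

```cpp
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

// Illustrative only: record a pooled, non-timing event on the given stream.
void recordPooledEvent(at::DeviceIndex device, at::cuda::CUDAStream stream) {
  // timing=false maps to cudaEventDisableTiming in the implementation above.
  auto event = c10d::CUDAEventCache::get(device)->create(/*timing=*/false);
  event->record(stream);
  // When `event` goes out of scope, its deleter returns the CUDAEvent to the
  // cache rather than calling cudaEventDestroy.
}
```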