[nccl symm mem] don't use arg for mempool, correctly use symmetric registration in hooks (#161238)

Per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161238
Approved by: https://github.com/kwen2501, https://github.com/syed-ahmed

Committed by PyTorch MergeBot
parent 74280d0913
commit 726dce3c94
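
In short, the symmetric flag moves from the MemPool constructor to the registration call. A minimal sketch of the resulting flow, modeled on the test change below (it assumes an already-initialized NCCL process group `pg`, its NCCL backend object `backend`, and this rank's CUDA `device`; everything beyond the torch calls shown in the diff is illustrative):

    import torch

    # the pool itself no longer carries a symmetric flag
    pool = torch.cuda.MemPool(backend.mem_allocator)

    # allocate from the pool
    with torch.cuda.use_mem_pool(pool):
        tensor = torch.arange(1024 * 1024 * 2, device=device, dtype=torch.float32)

    # symmetric (window) registration is requested when registering the pool
    backend.register_mem_pool(pool, symm=True)
    pg.allreduce(tensor).wait()
    torch.cuda.synchronize(device=device)

    # de-register buffers from NCCL when done
    backend.deregister_mem_pool(pool)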
@@ -4153,11 +4153,8 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};
 MemPool::MemPool(
     CUDACachingAllocator::CUDAAllocator* allocator,
     bool is_user_created,
-    bool use_on_oom,
-    bool symmetric)
-    : allocator_(allocator),
-      is_user_created_(is_user_created),
-      symmetric_(symmetric) {
+    bool use_on_oom)
+    : allocator_(allocator), is_user_created_(is_user_created) {
   if (is_user_created_) {
     id_ = {0, uid_++};
   } else {
@@ -4180,10 +4177,6 @@ MempoolId_t MemPool::id() {
   return id_;
 }
 
-bool MemPool::is_symmetric() {
-  return symmetric_;
-}
-
 CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
   return allocator_;
 }
@@ -538,8 +538,7 @@ struct C10_CUDA_API MemPool {
   MemPool(
       CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
       bool is_user_created = true,
-      bool use_on_oom = false,
-      bool symmetric = false);
+      bool use_on_oom = false);
   MemPool(const MemPool&) = delete;
   MemPool(MemPool&&) = default;
   MemPool& operator=(const MemPool&) = delete;
@@ -547,7 +546,6 @@ struct C10_CUDA_API MemPool {
   ~MemPool();
 
   MempoolId_t id();
-  bool is_symmetric();
   CUDACachingAllocator::CUDAAllocator* allocator();
   int use_count();
   c10::DeviceIndex device();
@@ -559,7 +557,6 @@ struct C10_CUDA_API MemPool {
   CUDACachingAllocator::CUDAAllocator* allocator_;
   bool is_user_created_;
   MempoolId_t id_;
-  bool symmetric_;
   c10::DeviceIndex device_;
 };
 
@@ -3099,7 +3099,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
 
-class NcclRegistrationTest(MultiProcessTestCase):
+class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3191,7 +3191,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
 
         # Use NCCL memory allocator
         # enable symmetric memory usage in NCCL
-        pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
+        pool = torch.cuda.MemPool(backend.mem_allocator)
 
         # allocate memory with ncclMemAlloc
         # note: symmetric kernels are not available for dtypes like torch.int64
@@ -3201,10 +3201,16 @@ class NcclRegistrationTest(MultiProcessTestCase):
         )
 
         # register buffers to NCCL
-        backend.register_mem_pool(pool)
+        backend.register_mem_pool(pool, symm=True)
 
         # allreduce now should use NVIDIA Switches
         pg.allreduce(tensor).wait()
+        # check that further allocations are also registered
+        with torch.cuda.use_mem_pool(pool):
+            tensor = torch.arange(
+                1024 * 1024 * 2, device=device, dtype=torch.float32
+            )
+            pg.allreduce(tensor).wait()
         torch.cuda.synchronize(device=device)
 
         # de-register buffers from NCCL
@@ -3217,7 +3223,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
             nccl_debug_file_content = f.read()
         # if buffers were registered and symmetric kernels ran, NCCL_DEBUG
         # should show successful registration in debug output
-        self.assertRegex(nccl_debug_file_content, "[Symmetric]")
+        self.assertRegex(nccl_debug_file_content, "Symmetric")
 
 
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
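
Aside on the last assertion change above: assertRegex compiles its pattern as a regular expression, so "[Symmetric]" is a character class that matches any single one of the letters S, y, m, e, t, r, i, c and therefore passes on virtually any NCCL debug output. Matching the literal substring "Symmetric" is what actually verifies that a window/symmetric registration was logged. A standalone illustration (not part of the change):

    import re

    log = "no window registration happened"
    assert re.search("[Symmetric]", log) is not None  # passes: matches a single letter such as "i"
    assert re.search("Symmetric", log) is None  # the literal word is absent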
@@ -2336,13 +2336,10 @@ class _MemPool:
         allocator: _cuda_CUDAAllocator | None = None,
         is_user_created: _bool = True,
         use_on_oom: _bool = False,
-        symmetric: _bool = False,
     ) -> None: ...
     @property
     def id(self) -> tuple[_int, _int]: ...
-    @property
-    def is_symmetric(self) -> _bool: ...
     @property
     def allocator(self) -> _cuda_CUDAAllocator | None: ...
     def use_count(self) -> _int: ...
 
@@ -16,15 +16,12 @@ void THCPMemPool_init(PyObject* module) {
       .def(
           py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
                       bool is_user_created,
-                      bool use_on_oom,
-                      bool symmetric) {
+                      bool use_on_oom) {
             torch::utils::device_lazy_init(at::kCUDA);
             return std::make_shared<::c10::cuda::MemPool>(
-                allocator, is_user_created, use_on_oom, symmetric);
+                allocator, is_user_created, use_on_oom);
           }))
       .def_property_readonly("id", &::c10::cuda::MemPool::id)
-      .def_property_readonly(
-          "is_symmetric", &::c10::cuda::MemPool::is_symmetric)
       .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
       .def("use_count", &::c10::cuda::MemPool::use_count);
 }
@@ -291,8 +291,12 @@ bool shouldAllCommunicatorsRegisterAllTensors() {
 // - This map has also to be maintained as global variable since the register
 //   hooks are called outside the scope of any PG, thus we need traverse
 //   communicators in all PGs.
-using MemPoolSet = std::
-    unordered_set<c10::cuda::MempoolId_t, c10::hash<c10::cuda::MempoolId_t>>;
+
+// MemPoolSet has ids of mempools used with this communicator, and whether they
+// were registered with window APIs or not
+using MemPoolSet = std::unordered_set<
+    std::tuple<c10::cuda::MempoolId_t, bool>,
+    c10::hash<std::tuple<c10::cuda::MempoolId_t, bool>>>;
 static std::unordered_map<std::shared_ptr<NCCLComm>, MemPoolSet>
     ncclCommMemPoolMap;
 static std::mutex ncclCommMemPoolMapMutex;
@@ -310,10 +314,23 @@ static void cacheAllocatorRegisterHook(
   std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
   for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) {
     if (te.device_ == ncclComm->getDeviceIndex()) {
-      if (shouldAllCommunicatorsRegisterAllTensors() ||
-          memPools.find(te.mempool_) != memPools.end()) {
+      bool symm = false;
+      bool should_register = shouldAllCommunicatorsRegisterAllTensors();
+      auto it =
+          std::find_if(memPools.begin(), memPools.end(), [&](const auto& tup) {
+            return std::get<0>(tup) == te.mempool_;
+          });
+      if (it != memPools.end()) {
+        should_register = true;
+        symm = std::get<1>(*it);
+      }
+      if (should_register) {
         // NOLINTNEXTLINE(performance-no-int-to-ptr)
-        ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
+        ncclComm->registerSegment(
+            reinterpret_cast<void*>(te.addr_),
+            te.size_,
+            /*errorOnRereg*/ false,
+            /*window*/ symm);
       }
     }
   }
@@ -330,10 +347,19 @@ static void cacheAllocatorDeregisterHook(
   std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
   for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) {
     if (te.device_ == ncclComm->getDeviceIndex()) {
-      if (shouldAllCommunicatorsRegisterAllTensors() ||
-          memPools.find(te.mempool_) != memPools.end()) {
+      bool symm = false;
+      bool should_register = shouldAllCommunicatorsRegisterAllTensors();
+      auto it =
+          std::find_if(memPools.begin(), memPools.end(), [&](const auto& tup) {
+            return std::get<0>(tup) == te.mempool_;
+          });
+      if (it != memPools.end()) {
+        should_register = true;
+        symm = std::get<1>(*it);
+      }
+      if (should_register) {
         // NOLINTNEXTLINE(performance-no-int-to-ptr)
-        ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
+        ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_), symm);
       }
     }
   }
@@ -968,8 +994,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   const std::string OFF = "OFF";
   std::string torch_distributed_debug =
       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
-  LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
-            << "size: " << size << ", global rank: " << globalRank()
+  LOG(INFO) << logPrefix()
+            << "ProcessGroupNCCL initialization options: " << "size: " << size
+            << ", global rank: " << globalRank()
             << ", TIMEOUT(ms): " << options_->timeout.count()
             << ", USE_HIGH_PRIORITY_STREAM: "
             << options_->is_high_priority_stream
@@ -1089,7 +1116,7 @@ ErrorType ProcessGroupNCCL::getError() {
   return error_;
 }
 
-void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
+void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) {
   const auto key = std::to_string(pool->device());
   LOG(INFO) << logPrefix()
             << "Performing NCCL user buffer registration for all buffers in "
@@ -1101,24 +1128,15 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
         DistBackendError,
         "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue");
   }
-  TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
   {
     std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
     auto iter = ncclCommMemPoolMap.find(ncclComm);
-    iter->second.insert(pool->id());
+    iter->second.insert(std::make_tuple(pool->id(), symm));
   }
   // We must ensure we're listening for allocator trace events in order to
   // register future segments allocated in this pool (this call is idempotent).
   attachAllocatorHooks();
   auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
-  // TODO:
-  // if(pool->is_symmetric()) {
-  // Allgather to verify len(mempool.snapshot.segments) matches across GPUs
-  // Allgather to verify mempool.alloc_request_counter matches across GPUs
-  // add alloc_request_counter per mempool (How many allocations a mempool has
-  // served during its lifetime) this should guarantee pool is used in a
-  // symmetric/SPMD manner
-  // }
   for (const auto& segmentInfo : snapshot.segments) {
     TORCH_INTERNAL_ASSERT(
         segmentInfo.device == pool->device(),
@@ -1128,31 +1146,35 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
         reinterpret_cast<void*>(segmentInfo.address),
         segmentInfo.total_size,
         /*errorOnRereg=*/false, // ignores reregistration error
-        /*window=*/pool->is_symmetric()); // whether to use NCCL symmetric
-                                          // memory
+        /*window*/ symm); // whether to use NCCL symmetric memory
   }
 }
 
 void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
   const auto key = std::to_string(pool->device());
-  auto device = at::Device(at::DeviceType::CUDA, pool->device());
   LOG(INFO) << logPrefix()
             << "Performing NCCL user buffer deregistration for all buffers in "
             << "MemPool: " << pool->id() << ", device index: " << key
             << ", i am " << this;
   auto ncclComm = getNCCLComm(key);
   if (ncclComm == nullptr) {
-    // HACK: currently we are using this function for NVLS
-    // reductions, and that's why using OpType::ALLREDUCE.
-    // If we end up using this API for zero-copy P2P, we might
-    // need to refactor and account for different OpType.
-    ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE);
+    C10_THROW_ERROR(
+        DistBackendError,
+        "NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue");
   }
-  TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
+  bool symm;
   {
     std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
     auto iter = ncclCommMemPoolMap.find(ncclComm);
-    iter->second.erase(pool->id());
+    auto mempool_it = std::find_if(
+        iter->second.begin(), iter->second.end(), [&](const auto& tup) {
+          return std::get<0>(tup) == pool->id();
+        });
+    TORCH_CHECK(
+        mempool_it != iter->second.end(),
+        "Trying to unregister not previously registered pool");
+    symm = std::get<1>(*mempool_it);
+    iter->second.erase(mempool_it);
   }
   auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
   for (const auto& segmentInfo : snapshot.segments) {
@@ -1161,7 +1183,7 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
         "Mismatch between CUDA memory segment device and pool's device");
     // NOLINTNEXTLINE(performance-no-int-to-ptr)
     ncclComm->deregisterSegment(
-        reinterpret_cast<void*>(segmentInfo.address), pool->is_symmetric());
+        reinterpret_cast<void*>(segmentInfo.address), symm);
   }
 }
 
@@ -5749,7 +5771,7 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
     // Pool is created
     memPool_ = std::make_unique<c10::cuda::MemPool>(allocator);
     // Register so that we call ncclCommRegister on all new allocations
-    registerMemPool(memPool_.get());
+    registerMemPool(memPool_.get(), /*symmetric*/ false);
     LOG(INFO) << logPrefix() << "Created memory pool";
   }
 
@@ -1002,7 +1002,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
 
   // Performs NCCL user buffer registration for all buffers in
   // the given MemPool
-  void registerMemPool(c10::cuda::MemPool* pool);
+  void registerMemPool(c10::cuda::MemPool* pool, bool symm = false);
 
   // Performs NCCL user buffer de-registration for all buffers in
   // the given MemPool
@@ -3334,7 +3334,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def(
           "perform_nocolor_split",
          &::c10d::ProcessGroupNCCL::performNocolorSplit)
-      .def("register_mem_pool", &::c10d::ProcessGroupNCCL::registerMemPool)
+      .def(
+          "register_mem_pool",
+          &::c10d::ProcessGroupNCCL::registerMemPool,
+          py::arg("pool"),
+          py::arg("symm") = false)
       .def(
           "deregister_mem_pool",
           &::c10d::ProcessGroupNCCL::deregisterMemPool)
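
Since the binding defaults symm to False, existing Python callers of register_mem_pool are unchanged; only callers that want NCCL window (symmetric) registration pass the new keyword. Roughly, the two call forms are (assuming backend is a ProcessGroupNCCL backend object and pool a torch.cuda.MemPool):

    backend.register_mem_pool(pool)             # plain buffer registration, as before
    backend.register_mem_pool(pool, symm=True)  # opt in to window (symmetric) registration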
@@ -1169,28 +1169,21 @@ class MemPool(_MemPool):
         use_on_oom(bool): a bool that indicates if this pool can be used
             as a last resort if a memory allocation outside of the pool fails due
             to Out Of Memory. This is False by default.
-        symmetric(bool): a bool that indicates if this pool is symmetrical
-            across ranks. This is False by default.
 
     """
 
     def __init__(
         self,
         allocator: Optional[_cuda_CUDAAllocator] = None,
         use_on_oom: bool = False,
-        symmetric: bool = False,
     ):
-        super().__init__(allocator, True, use_on_oom, symmetric)
+        super().__init__(allocator, True, use_on_oom)
 
     @property
     def id(self) -> tuple[int, int]:
         r"""Returns the ID of this pool as a tuple of two ints."""
         return super().id
 
-    @property
-    def is_symmetric(self) -> bool:
-        r"""Returns whether this pool is used for NCCL's symmetric memory."""
-        return super().is_symmetric
-
     @property
     def allocator(self) -> Optional[_cuda_CUDAAllocator]:
         r"""Returns the allocator this MemPool routes allocations to."""