[nccl symm mem] don't use arg for mempool, correctly use symmetric registration in hooks (#161238)

Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161238
Approved by: https://github.com/kwen2501, https://github.com/syed-ahmed
This commit is contained in:
Natalia Gimelshein
2025-08-25 03:09:32 +00:00
committed by PyTorch MergeBot
parent 74280d0913
commit 726dce3c94
9 changed files with 78 additions and 69 deletions

View File

@ -4153,11 +4153,8 @@ std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom,
bool symmetric)
: allocator_(allocator),
is_user_created_(is_user_created),
symmetric_(symmetric) {
bool use_on_oom)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
@ -4180,10 +4177,6 @@ MempoolId_t MemPool::id() {
return id_;
}
bool MemPool::is_symmetric() {
return symmetric_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}

View File

@ -538,8 +538,7 @@ struct C10_CUDA_API MemPool {
MemPool(
CUDACachingAllocator::CUDAAllocator* allocator = nullptr,
bool is_user_created = true,
bool use_on_oom = false,
bool symmetric = false);
bool use_on_oom = false);
MemPool(const MemPool&) = delete;
MemPool(MemPool&&) = default;
MemPool& operator=(const MemPool&) = delete;
@ -547,7 +546,6 @@ struct C10_CUDA_API MemPool {
~MemPool();
MempoolId_t id();
bool is_symmetric();
CUDACachingAllocator::CUDAAllocator* allocator();
int use_count();
c10::DeviceIndex device();
@ -559,7 +557,6 @@ struct C10_CUDA_API MemPool {
CUDACachingAllocator::CUDAAllocator* allocator_;
bool is_user_created_;
MempoolId_t id_;
bool symmetric_;
c10::DeviceIndex device_;
};

View File

@ -3099,7 +3099,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
self._run_invalid_nccl_blocking_wait_env("4294967295")
class NcclRegistrationTest(MultiProcessTestCase):
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@ -3191,7 +3191,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
# Use NCCL memory allocator
# enable symmetric memory usage in NCCL
pool = torch.cuda.MemPool(backend.mem_allocator, symmetric=True)
pool = torch.cuda.MemPool(backend.mem_allocator)
# allocate memory with ncclMemAlloc
# note: symmetric kernels are not available for dtypes like torch.int64
@ -3201,10 +3201,16 @@ class NcclRegistrationTest(MultiProcessTestCase):
)
# register buffers to NCCL
backend.register_mem_pool(pool)
backend.register_mem_pool(pool, symm=True)
# allreduce now should use NVIDIA Switches
pg.allreduce(tensor).wait()
# check that further allocations are also registered
with torch.cuda.use_mem_pool(pool):
tensor = torch.arange(
1024 * 1024 * 2, device=device, dtype=torch.float32
)
pg.allreduce(tensor).wait()
torch.cuda.synchronize(device=device)
# de-register buffers from NCCL
@ -3217,7 +3223,7 @@ class NcclRegistrationTest(MultiProcessTestCase):
nccl_debug_file_content = f.read()
# if buffers were registered and symmetric kernels ran, NCCL_DEBUG
# should show successful registration in debug output
self.assertRegex(nccl_debug_file_content, "[Symmetric]")
self.assertRegex(nccl_debug_file_content, "Symmetric")
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):

View File

@ -2336,13 +2336,10 @@ class _MemPool:
allocator: _cuda_CUDAAllocator | None = None,
is_user_created: _bool = True,
use_on_oom: _bool = False,
symmetric: _bool = False,
) -> None: ...
@property
def id(self) -> tuple[_int, _int]: ...
@property
def is_symmetric(self) -> _bool: ...
@property
def allocator(self) -> _cuda_CUDAAllocator | None: ...
def use_count(self) -> _int: ...

View File

@ -16,15 +16,12 @@ void THCPMemPool_init(PyObject* module) {
.def(
py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created,
bool use_on_oom,
bool symmetric) {
bool use_on_oom) {
torch::utils::device_lazy_init(at::kCUDA);
return std::make_shared<::c10::cuda::MemPool>(
allocator, is_user_created, use_on_oom, symmetric);
allocator, is_user_created, use_on_oom);
}))
.def_property_readonly("id", &::c10::cuda::MemPool::id)
.def_property_readonly(
"is_symmetric", &::c10::cuda::MemPool::is_symmetric)
.def_property_readonly("allocator", &::c10::cuda::MemPool::allocator)
.def("use_count", &::c10::cuda::MemPool::use_count);
}

View File

@ -291,8 +291,12 @@ bool shouldAllCommunicatorsRegisterAllTensors() {
// - This map has also to be maintained as global variable since the register
// hooks are called outside the scope of any PG, thus we need traverse
// communicators in all PGs.
using MemPoolSet = std::
unordered_set<c10::cuda::MempoolId_t, c10::hash<c10::cuda::MempoolId_t>>;
// MemPoolSet has ids of mempools used with this communicator, and whether they
// were registered with window APIs or not
using MemPoolSet = std::unordered_set<
std::tuple<c10::cuda::MempoolId_t, bool>,
c10::hash<std::tuple<c10::cuda::MempoolId_t, bool>>>;
static std::unordered_map<std::shared_ptr<NCCLComm>, MemPoolSet>
ncclCommMemPoolMap;
static std::mutex ncclCommMemPoolMapMutex;
@ -310,10 +314,23 @@ static void cacheAllocatorRegisterHook(
std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) {
if (te.device_ == ncclComm->getDeviceIndex()) {
if (shouldAllCommunicatorsRegisterAllTensors() ||
memPools.find(te.mempool_) != memPools.end()) {
bool symm = false;
bool should_register = shouldAllCommunicatorsRegisterAllTensors();
auto it =
std::find_if(memPools.begin(), memPools.end(), [&](const auto& tup) {
return std::get<0>(tup) == te.mempool_;
});
if (it != memPools.end()) {
should_register = true;
symm = std::get<1>(*it);
}
if (should_register) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->registerSegment(reinterpret_cast<void*>(te.addr_), te.size_);
ncclComm->registerSegment(
reinterpret_cast<void*>(te.addr_),
te.size_,
/*errorOnRereg*/ false,
/*window*/ symm);
}
}
}
@ -330,10 +347,19 @@ static void cacheAllocatorDeregisterHook(
std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
for (auto& [ncclComm, memPools] : ncclCommMemPoolMap) {
if (te.device_ == ncclComm->getDeviceIndex()) {
if (shouldAllCommunicatorsRegisterAllTensors() ||
memPools.find(te.mempool_) != memPools.end()) {
bool symm = false;
bool should_register = shouldAllCommunicatorsRegisterAllTensors();
auto it =
std::find_if(memPools.begin(), memPools.end(), [&](const auto& tup) {
return std::get<0>(tup) == te.mempool_;
});
if (it != memPools.end()) {
should_register = true;
symm = std::get<1>(*it);
}
if (should_register) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_));
ncclComm->deregisterSegment(reinterpret_cast<void*>(te.addr_), symm);
}
}
}
@ -968,8 +994,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
const std::string OFF = "OFF";
std::string torch_distributed_debug =
getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
<< "size: " << size << ", global rank: " << globalRank()
LOG(INFO) << logPrefix()
<< "ProcessGroupNCCL initialization options: " << "size: " << size
<< ", global rank: " << globalRank()
<< ", TIMEOUT(ms): " << options_->timeout.count()
<< ", USE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
@ -1089,7 +1116,7 @@ ErrorType ProcessGroupNCCL::getError() {
return error_;
}
void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool, bool symm) {
const auto key = std::to_string(pool->device());
LOG(INFO) << logPrefix()
<< "Performing NCCL user buffer registration for all buffers in "
@ -1101,24 +1128,15 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
DistBackendError,
"NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue");
}
TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
{
std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
auto iter = ncclCommMemPoolMap.find(ncclComm);
iter->second.insert(pool->id());
iter->second.insert(std::make_tuple(pool->id(), symm));
}
// We must ensure we're listening for allocator trace events in order to
// register future segments allocated in this pool (this call is idempotent).
attachAllocatorHooks();
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
// TODO:
// if(pool->is_symmetric()) {
// Allgather to verify len(mempool.snapshot.segments) matches across GPUs
// Allgather to verify mempool.alloc_request_counter matches across GPUs
// add alloc_request_counter per mempool (How many allocations a mempool has
// served during its lifetime) this should guarantee pool is used in a
// symmetric/SPMD manner
// }
for (const auto& segmentInfo : snapshot.segments) {
TORCH_INTERNAL_ASSERT(
segmentInfo.device == pool->device(),
@ -1128,31 +1146,35 @@ void ProcessGroupNCCL::registerMemPool(c10::cuda::MemPool* pool) {
reinterpret_cast<void*>(segmentInfo.address),
segmentInfo.total_size,
/*errorOnRereg=*/false, // ignores reregistration error
/*window=*/pool->is_symmetric()); // whether to use NCCL symmetric
// memory
/*window*/ symm); // whether to use NCCL symmetric memory
}
}
void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
const auto key = std::to_string(pool->device());
auto device = at::Device(at::DeviceType::CUDA, pool->device());
LOG(INFO) << logPrefix()
<< "Performing NCCL user buffer deregistration for all buffers in "
<< "MemPool: " << pool->id() << ", device index: " << key
<< ", i am " << this;
auto ncclComm = getNCCLComm(key);
if (ncclComm == nullptr) {
// HACK: currently we are using this function for NVLS
// reductions, and that's why using OpType::ALLREDUCE.
// If we end up using this API for zero-copy P2P, we might
// need to refactor and account for different OpType.
ncclComm = initNCCLComm(key, device, OpType::ALLREDUCE);
C10_THROW_ERROR(
DistBackendError,
"NCCL communicator has not been initialized before mem pool creation. You can pass `device_id` to init_process_group -- one way of eager initialization -- to work around this issue");
}
TORCH_INTERNAL_ASSERT(ncclComm != nullptr);
bool symm;
{
std::lock_guard<std::mutex> lock(ncclCommMemPoolMapMutex);
auto iter = ncclCommMemPoolMap.find(ncclComm);
iter->second.erase(pool->id());
auto mempool_it = std::find_if(
iter->second.begin(), iter->second.end(), [&](const auto& tup) {
return std::get<0>(tup) == pool->id();
});
TORCH_CHECK(
mempool_it != iter->second.end(),
"Trying to unregister not previously registered pool");
symm = std::get<1>(*mempool_it);
iter->second.erase(mempool_it);
}
auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(pool->id());
for (const auto& segmentInfo : snapshot.segments) {
@ -1161,7 +1183,7 @@ void ProcessGroupNCCL::deregisterMemPool(c10::cuda::MemPool* pool) {
"Mismatch between CUDA memory segment device and pool's device");
// NOLINTNEXTLINE(performance-no-int-to-ptr)
ncclComm->deregisterSegment(
reinterpret_cast<void*>(segmentInfo.address), pool->is_symmetric());
reinterpret_cast<void*>(segmentInfo.address), symm);
}
}
@ -5749,7 +5771,7 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
// Pool is created
memPool_ = std::make_unique<c10::cuda::MemPool>(allocator);
// Register so that we call ncclCommRegister on all new allocations
registerMemPool(memPool_.get());
registerMemPool(memPool_.get(), /*symmetric*/ false);
LOG(INFO) << logPrefix() << "Created memory pool";
}

View File

@ -1002,7 +1002,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Performs NCCL user buffer registration for all buffers in
// the given MemPool
void registerMemPool(c10::cuda::MemPool* pool);
void registerMemPool(c10::cuda::MemPool* pool, bool symm = false);
// Performs NCCL user buffer de-registration for all buffers in
// the given MemPool

View File

@ -3334,7 +3334,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def(
"perform_nocolor_split",
&::c10d::ProcessGroupNCCL::performNocolorSplit)
.def("register_mem_pool", &::c10d::ProcessGroupNCCL::registerMemPool)
.def(
"register_mem_pool",
&::c10d::ProcessGroupNCCL::registerMemPool,
py::arg("pool"),
py::arg("symm") = false)
.def(
"deregister_mem_pool",
&::c10d::ProcessGroupNCCL::deregisterMemPool)

View File

@ -1169,28 +1169,21 @@ class MemPool(_MemPool):
use_on_oom(bool): a bool that indicates if this pool can be used
as a last resort if a memory allocation outside of the pool fails due
to Out Of Memory. This is False by default.
symmetric(bool): a bool that indicates if this pool is symmetrical
across ranks. This is False by default.
"""
def __init__(
self,
allocator: Optional[_cuda_CUDAAllocator] = None,
use_on_oom: bool = False,
symmetric: bool = False,
):
super().__init__(allocator, True, use_on_oom, symmetric)
super().__init__(allocator, True, use_on_oom)
@property
def id(self) -> tuple[int, int]:
r"""Returns the ID of this pool as a tuple of two ints."""
return super().id
@property
def is_symmetric(self) -> bool:
r"""Returns whether this pool is used for NCCL's symmetric memory."""
return super().is_symmetric
@property
def allocator(self) -> Optional[_cuda_CUDAAllocator]:
r"""Returns the allocator this MemPool routes allocations to."""