[2/N][SymmMem] Add MemPool allocator and tests (#161471)

(Porting most of #161008)

Hook the SymmetricMemory allocator into MemPool so that users can create symmetric tensors with the regular factory functions such as `torch.zeros` and `torch.arange`. This also allows our ops to have functional variants that create their `out` tensors on symmetric memory.

For end users, this PR provides the following Python UI:
```
allocator = symm_mem.get_mempool_allocator(device)
mempool = torch.cuda.MemPool(allocator)
with torch.cuda.use_mem_pool(mempool):
    tensor = torch.arange(numel, dtype=dtype, device=device)
```

Added tests for both use cases above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161471
Approved by: https://github.com/ngimel
ghstack dependencies: #161470
This commit is contained in:
Ke Wen
2025-08-26 16:38:22 -07:00
committed by PyTorch MergeBot
parent 0fd63fd88b
commit b291dc9684
10 changed files with 138 additions and 0 deletions

View File

@ -747,6 +747,7 @@ cc_library(
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
"torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
],
)) + torch_sources,

View File

@ -755,6 +755,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
"torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
]

View File

@ -581,6 +581,7 @@ if(USE_CUDA)
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
)
endif()

View File

@ -65,6 +65,58 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
out = symm_mem.empty(numel, dtype=dtype, device=self.device)
symm_mem.rendezvous(out, group=group_name)
@skipIfRocm
def test_mempool_tensor_factory(self) -> None:
"""
Test the effectiveness of MemPool on tensor factory ops.
"""
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
numel = 1024
src_rank = 0
# Build a MemPool on top of the symmetric-memory allocator so that the
# factory calls made under `use_mem_pool` draw from that pool.
allocator = symm_mem.get_mempool_allocator(self.device)
mempool = torch.cuda.MemPool(allocator)
with torch.cuda.use_mem_pool(mempool):
# The source rank holds the payload; every other rank starts from zeros.
if self.rank == src_rank:
tensor = torch.arange(numel, dtype=dtype, device=self.device)
else:
tensor = torch.zeros(numel, dtype=dtype, device=self.device)
# Passing a factory-created tensor to rendezvous and then to an NVSHMEM
# collective is the behavior under test.
symm_mem.rendezvous(tensor, group=group_name)
# NOTE(review): no root is passed; presumably the op broadcasts from rank 0
# (== src_rank) -- confirm against the op's signature.
torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
# Every rank must now hold the source rank's `arange` payload.
self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
@skipIfRocm
def test_mempool_compute_ops(self) -> None:
"""
Apply MemPool context to a compute op that creates input to collective.
"""
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
dim = 1024
# `w` and `x0` are created before the pool is built and entered.
w = torch.ones(dim, dim, dtype=dtype, device=self.device)
x0 = torch.ones(1, dim, dtype=dtype, device=self.device)
allocator = symm_mem.get_mempool_allocator(self.device)
mempool = torch.cuda.MemPool(allocator)
with torch.cuda.use_mem_pool(mempool):
# Adding the rank makes `x` (and therefore `y`) differ across ranks.
x = x0 + self.rank
y = torch.mm(x, w)
# y should be a symm tensor
torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
# `expected` equals what rank 0 computed (x0 + 0), so the assertion can
# only hold on the other ranks if the broadcast overwrote `y`.
# NOTE(review): presumably the broadcast root defaults to rank 0 -- confirm.
expected = torch.mm(x0, w)
self.assertEqual(y, expected)
@skipIfRocm
def test_nvshmem_put(self) -> None:
self._init_device()

View File

@ -769,6 +769,8 @@ class _SymmetricMemory:
def set_backend(name: str) -> None: ...
@staticmethod
def get_backend(device: torch.device) -> Optional[str]: ...
# Returns the MemPool allocator registered for the device type of `device`.
# Typed as `Any` because the stub does not name the allocator's Python type.
@staticmethod
def get_mempool_allocator(device: torch.device) -> Any: ...
@property
def rank(self) -> int: ...
@property

View File

@ -1128,6 +1128,9 @@ This class does not support ``__members__`` property.)");
&::c10d::symmetric_memory::has_multicast_support)
.def_static("set_backend", &::c10d::symmetric_memory::set_backend)
.def_static("get_backend", &::c10d::symmetric_memory::get_backend)
.def_static(
"get_mempool_allocator",
&::c10d::symmetric_memory::get_mempool_allocator)
.def_property_readonly("rank", &SymmetricMemory::get_rank)
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
.def_property_readonly(

View File

@ -266,6 +266,28 @@ TORCH_API bool has_multicast_support(
return allocator->has_multicast_support(device_idx);
}
}
// Returns the process-wide registry that maps a device type to the allocator
// MemPool should use for symmetric memory on that device type.
//
// The registry is a function-local static (constructed on first use) rather
// than a namespace-scope object. `register_mempool_allocator` is meant to be
// called from static initializers in other translation units, and the order
// in which statics in different translation units are initialized is
// unspecified; a namespace-scope map could therefore be accessed before its
// constructor has run.
static std::unordered_map<c10::DeviceType, std::shared_ptr<c10::Allocator>>&
mempool_allocators() {
  static std::unordered_map<c10::DeviceType, std::shared_ptr<c10::Allocator>>
      allocators;
  return allocators;
}

// Registers `allocator` as the MemPool allocator for `device_type`.
// A later registration for the same device type replaces the earlier one.
void register_mempool_allocator(
    c10::DeviceType device_type,
    std::shared_ptr<c10::Allocator> allocator) {
  mempool_allocators()[device_type] = std::move(allocator);
}

// Get allocator for MemPool given device.
// Only the device type is used for the lookup; the device index is ignored.
// Fails a TORCH_CHECK if no allocator has been registered for that type.
std::shared_ptr<c10::Allocator> get_mempool_allocator(c10::Device device) {
  const auto& allocators = mempool_allocators();
  const auto it = allocators.find(device.type());
  TORCH_CHECK(
      it != allocators.end(),
      "SymmetricMemory MemPool did not find backend for device type ",
      device.type());
  return it->second;
}
} // namespace c10d::symmetric_memory
namespace {

View File

@ -184,4 +184,11 @@ TORCH_API void set_backend(const std::string& name);
TORCH_API std::optional<std::string> get_backend(c10::Device device);
// Registers `allocator` as the MemPool allocator for `device_type`.
// Registering again for the same device type replaces the previous entry.
C10_EXPORT void register_mempool_allocator(
c10::DeviceType device_type,
std::shared_ptr<c10::Allocator> allocator);
// Returns the MemPool allocator registered for the type of `device`.
// Raises (via TORCH_CHECK) if nothing has been registered for that type.
TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
c10::Device device);
} // namespace c10d::symmetric_memory

View File

@ -0,0 +1,39 @@
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
namespace {
using namespace c10d::symmetric_memory;
// Allocation hook handed to MemPool: serves the request from the CUDA
// symmetric-memory allocator. `stream` is part of the hook signature but is
// not used here.
void* cuda_symm_alloc(size_t size, int device, void* /*stream*/) {
  // Looked up on the first call and reused by every later call.
  static auto allocator = get_allocator(c10::DeviceType::CUDA);
  TORCH_CHECK(
      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
  // MemPool's invocation carries no group, so `nullopt` is passed for
  // `group_name`. This hook therefore only suits backends that accept (in
  // fact require) `nullopt` there; of those, only NVSHMEM passes the check
  // above.
  void* const symm_ptr =
      allocator->alloc(size, device, /*group_name=*/std::nullopt);
  return symm_ptr;
}
// Deallocation hook handed to MemPool: gives `ptr` back to the CUDA
// symmetric-memory allocator. `size`, `device` and `stream` are part of the
// hook signature but are not used here.
void cuda_symm_free(
    void* ptr,
    size_t /*size*/,
    int /*device*/,
    void* /*stream*/) {
  // Looked up on the first call and reused by every later call.
  static auto allocator = get_allocator(c10::DeviceType::CUDA);
  TORCH_CHECK(
      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
  allocator->free(ptr);
}
// Register allocator for CUDA MemPool
struct RegisterCUDAMemPoolAllocator {
RegisterCUDAMemPoolAllocator() {
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator =
torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
cuda_symm_alloc, cuda_symm_free);
register_mempool_allocator(c10::DeviceType::CUDA, allocator);
}
};
static RegisterCUDAMemPoolAllocator register_cuda_mempool_allocator_;
} // namespace

View File

@ -1781,4 +1781,14 @@ def get_backend(device: _device) -> Optional[str]:
return _SymmetricMemory.get_backend(torch.device(device))
def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def]
    r"""
    Return the symmetric-memory MemPool allocator for the given device.

    Args:
        device (:class:`torch.device` or str): the device whose MemPool
            allocator is requested.
    """
    target_device = torch.device(device)
    return _SymmetricMemory.get_mempool_allocator(target_device)
# Names exported by a star-import of this module. `get_mempool_allocator` is
# listed alongside the other user-facing helpers such as `get_backend`.
__all__ = [
    "empty",
    "rendezvous",
    "is_nvshmem_available",
    "set_backend",
    "get_backend",
    "get_mempool_allocator",
]