diff --git a/BUILD.bazel b/BUILD.bazel
index 58ebc31e243c..d4202e7a2c1e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -747,6 +747,7 @@ cc_library(
         "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
         "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
         "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
+        "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
         "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
     ],
 )) + torch_sources,
diff --git a/build_variables.bzl b/build_variables.bzl
index fb9314e2c7a0..fd53c9e8aa12 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -755,6 +755,7 @@ libtorch_cuda_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
    "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
    "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
+    "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
    "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
 ]
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 375228a5b75e..86a57264d253 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -581,6 +581,7 @@ if(USE_CUDA)
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
+      ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
       PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
     )
   endif()
diff --git a/test/distributed/test_nvshmem.py b/test/distributed/test_nvshmem.py
index 64b8062b6098..f8567cdad077 100644
--- a/test/distributed/test_nvshmem.py
+++ b/test/distributed/test_nvshmem.py
@@ -65,6 +65,58 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
         out = symm_mem.empty(numel, dtype=dtype, device=self.device)
         symm_mem.rendezvous(out, group=group_name)
 
+    @skipIfRocm
+    def test_mempool_tensor_factory(self) -> None:
+        """
+        Test the effectiveness of MemPool on tensor factory ops.
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        numel = 1024
+        src_rank = 0
+
+        allocator = symm_mem.get_mempool_allocator(self.device)
+        mempool = torch.cuda.MemPool(allocator)
+
+        with torch.cuda.use_mem_pool(mempool):
+            if self.rank == src_rank:
+                tensor = torch.arange(numel, dtype=dtype, device=self.device)
+            else:
+                tensor = torch.zeros(numel, dtype=dtype, device=self.device)
+
+        symm_mem.rendezvous(tensor, group=group_name)
+        torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
+        self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
+
+    @skipIfRocm
+    def test_mempool_compute_ops(self) -> None:
+        """
+        Apply the MemPool context to a compute op that creates the input to a collective.
+ """ + self._init_device() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + dtype = torch.float + dim = 1024 + w = torch.ones(dim, dim, dtype=dtype, device=self.device) + x0 = torch.ones(1, dim, dtype=dtype, device=self.device) + + allocator = symm_mem.get_mempool_allocator(self.device) + mempool = torch.cuda.MemPool(allocator) + + with torch.cuda.use_mem_pool(mempool): + x = x0 + self.rank + y = torch.mm(x, w) + + # y should be a symm tensor + torch.ops.symm_mem.nvshmem_broadcast(y, group_name) + expected = torch.mm(x0, w) + self.assertEqual(y, expected) + @skipIfRocm def test_nvshmem_put(self) -> None: self._init_device() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 72fde27d0257..0622cdf461aa 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -769,6 +769,8 @@ class _SymmetricMemory: def set_backend(name: str) -> None: ... @staticmethod def get_backend(device: torch.device) -> Optional[str]: ... + @staticmethod + def get_mempool_allocator(device: torch.device) -> Any: ... @property def rank(self) -> int: ... @property diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index a0904a814637..fd612d46abad 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1128,6 +1128,9 @@ This class does not support ``__members__`` property.)"); &::c10d::symmetric_memory::has_multicast_support) .def_static("set_backend", &::c10d::symmetric_memory::set_backend) .def_static("get_backend", &::c10d::symmetric_memory::get_backend) + .def_static( + "get_mempool_allocator", + &::c10d::symmetric_memory::get_mempool_allocator) .def_property_readonly("rank", &SymmetricMemory::get_rank) .def_property_readonly("world_size", &SymmetricMemory::get_world_size) .def_property_readonly( diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 2831a4416de9..97aec6a87d3b 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -266,6 +266,61 @@ TORCH_API bool has_multicast_support( return allocator->has_multicast_support(device_idx); } } + +// MemPool Support + +// A map from device type to allocator for MemPool. +// TODO: Consolidate with `AllocatorMap` above. 
+// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
+class MemPoolAllocatorMap {
+ public:
+  MemPoolAllocatorMap(const MemPoolAllocatorMap&) = delete;
+  MemPoolAllocatorMap& operator=(const MemPoolAllocatorMap&) = delete;
+  static MemPoolAllocatorMap& get() {
+    static MemPoolAllocatorMap instance;
+    return instance;
+  }
+
+  // Register allocator for MemPool given device type
+  void register_mempool_allocator(
+      c10::DeviceType device_type,
+      std::shared_ptr<c10::Allocator> allocator) {
+    mempool_allocators_[device_type] = std::move(allocator);
+  }
+
+  // Get allocator for MemPool given device
+  std::shared_ptr<c10::Allocator> get_mempool_allocator(c10::Device device) {
+    auto it = mempool_allocators_.find(device.type());
+    if (it == mempool_allocators_.end()) {
+      TORCH_CHECK(
+          false,
+          "SymmetricMemory MemPool did not find backend for device type ",
+          device.type());
+    }
+    return it->second;
+  }
+
+ private:
+  MemPoolAllocatorMap() = default;
+
+  std::unordered_map<c10::DeviceType, std::shared_ptr<c10::Allocator>>
+      mempool_allocators_;
+};
+
+// Register allocator for MemPool given device type
+C10_EXPORT void register_mempool_allocator(
+    c10::DeviceType device_type,
+    std::shared_ptr<c10::Allocator> allocator) {
+  return MemPoolAllocatorMap::get().register_mempool_allocator(
+      device_type, std::move(allocator));
+}
+
+// Get allocator for MemPool given device
+TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
+    c10::Device device) {
+  return MemPoolAllocatorMap::get().get_mempool_allocator(device);
+}
+
 } // namespace c10d::symmetric_memory
 
 namespace {
diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp
index c2828de04c9b..82586239a231 100644
--- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp
+++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp
@@ -184,4 +184,11 @@ TORCH_API void set_backend(const std::string& name);
 
 TORCH_API std::optional<std::string> get_backend(c10::Device device);
 
+C10_EXPORT void register_mempool_allocator(
+    c10::DeviceType device_type,
+    std::shared_ptr<c10::Allocator> allocator);
+
+TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
+    c10::Device device);
+
 } // namespace c10d::symmetric_memory
diff --git a/torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp b/torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
new file mode 100644
index 000000000000..bfbe02bd6f86
--- /dev/null
+++ b/torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
@@ -0,0 +1,39 @@
+#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
+#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
+
+namespace {
+using namespace c10d::symmetric_memory;
+
+// Alloc functor for MemPool
+void* cuda_symm_alloc(size_t size, int device, void* stream) {
+  static auto allocator = get_allocator(c10::DeviceType::CUDA);
+  TORCH_CHECK(
+      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
+  // Note: this alloc functor works for the NVSHMEM and NCCL backends only,
+  // because only these backends take `nullopt` for the `group` argument,
+  // which is not supplied by MemPool's invocation (in fact, these two
+  // backends require it to be `nullopt`).
+  return allocator->alloc(size, device, /*group_name=*/std::nullopt);
+}
+
+// Free functor for MemPool
+void cuda_symm_free(void* ptr, size_t size, int device, void* stream) {
+  static auto allocator = get_allocator(c10::DeviceType::CUDA);
+  TORCH_CHECK(
+      allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
+  allocator->free(ptr);
+}
+
+// Register allocator for CUDA MemPool
+struct RegisterCUDAMemPoolAllocator {
+  RegisterCUDAMemPoolAllocator() {
+    std::shared_ptr<c10::Allocator> allocator =
+        torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
+            cuda_symm_alloc, cuda_symm_free);
+    register_mempool_allocator(c10::DeviceType::CUDA, allocator);
+  }
+};
+
+static RegisterCUDAMemPoolAllocator register_cuda_mempool_allocator_;
+
+} // namespace
diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py
index 1622ebc66a01..43c2959fdd8d 100644
--- a/torch/distributed/_symmetric_memory/__init__.py
+++ b/torch/distributed/_symmetric_memory/__init__.py
@@ -1782,4 +1782,14 @@ def get_backend(device: _device) -> str | None:
     return _SymmetricMemory.get_backend(torch.device(device))
 
 
+def get_mempool_allocator(device: _device):  # type: ignore[no-untyped-def]
+    r"""
+    Get the MemPool allocator for symmetric memory for a given device.
+    Args:
+        device (:class:`torch.device` or str): the device for which to get the
+            MemPool allocator.
+    """
+    return _SymmetricMemory.get_mempool_allocator(torch.device(device))
+
+
 __all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
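
For reference, a minimal usage sketch of the new MemPool allocator hook, mirroring the tests added above. It assumes the NVSHMEM symmetric-memory backend and a process group launched via torchrun; everything outside the APIs added in this diff follows the existing test file and standard `torch.distributed` setup.

```python
# Sketch: allocate a tensor from symmetric memory via MemPool, then feed it
# to a symmetric-memory collective. Assumes torchrun-provided env vars.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

# New API from this diff: a pluggable allocator backed by symmetric memory.
allocator = symm_mem.get_mempool_allocator(device)
mempool = torch.cuda.MemPool(allocator)

# Allocations made under this context come from the symmetric-memory heap,
# so ordinary factory/compute ops can produce collective-ready tensors.
with torch.cuda.use_mem_pool(mempool):
    t = torch.full((1024,), float(rank), dtype=torch.float, device=device)

symm_mem.rendezvous(t, group=group_name)
torch.ops.symm_mem.nvshmem_broadcast(t, group_name)  # t now holds rank 0's data

dist.destroy_process_group()
```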