mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[2/N][SymmMem] Add MemPool allocator and tests (#161471)
(Porting most of #161008) Hooking SymmetricMemory Allocator to MemPool so that user can create symmetric tensors with regular `torch.zeros`, `torch.arange` etc factories. Also so that our ops can have functional variants that create `out` tensors on symmetric memory. To end users, this PR supports a python UI as follows: ``` allocator = symm_mem.get_mempool_allocator(device) mempool = torch.cuda.MemPool(allocator) with torch.cuda.use_mem_pool(mempool): tensor = torch.arange(numel, dtype=dtype, device=device) ``` Added tests for both use cases above. Differential Revision: [](https://our.internmc.facebook.com/intern/diff/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161471 Approved by: https://github.com/ngimel ghstack dependencies: #161470
This commit is contained in:
@ -747,6 +747,7 @@ cc_library(
|
|||||||
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
|
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
|
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
|
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
|
||||||
|
"torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
||||||
],
|
],
|
||||||
)) + torch_sources,
|
)) + torch_sources,
|
||||||
|
@ -755,6 +755,7 @@ libtorch_cuda_distributed_extra_sources = [
|
|||||||
"torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
|
"torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
|
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
||||||
|
"torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
|
||||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -581,6 +581,7 @@ if(USE_CUDA)
|
|||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
|
||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
|
||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
|
||||||
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
|
||||||
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
@ -65,6 +65,58 @@ class NVSHMEMSymmetricMemoryTest(MultiProcContinuousTest):
|
|||||||
out = symm_mem.empty(numel, dtype=dtype, device=self.device)
|
out = symm_mem.empty(numel, dtype=dtype, device=self.device)
|
||||||
symm_mem.rendezvous(out, group=group_name)
|
symm_mem.rendezvous(out, group=group_name)
|
||||||
|
|
||||||
|
@skipIfRocm
|
||||||
|
def test_mempool_tensor_factory(self) -> None:
|
||||||
|
"""
|
||||||
|
Test the effectiveness of MemPool on tensor factory ops.
|
||||||
|
"""
|
||||||
|
self._init_device()
|
||||||
|
group_name = dist.group.WORLD.group_name
|
||||||
|
symm_mem.enable_symm_mem_for_group(group_name)
|
||||||
|
|
||||||
|
dtype = torch.float
|
||||||
|
numel = 1024
|
||||||
|
src_rank = 0
|
||||||
|
|
||||||
|
allocator = symm_mem.get_mempool_allocator(self.device)
|
||||||
|
mempool = torch.cuda.MemPool(allocator)
|
||||||
|
|
||||||
|
with torch.cuda.use_mem_pool(mempool):
|
||||||
|
if self.rank == src_rank:
|
||||||
|
tensor = torch.arange(numel, dtype=dtype, device=self.device)
|
||||||
|
else:
|
||||||
|
tensor = torch.zeros(numel, dtype=dtype, device=self.device)
|
||||||
|
|
||||||
|
symm_mem.rendezvous(tensor, group=group_name)
|
||||||
|
torch.ops.symm_mem.nvshmem_broadcast(tensor, group_name)
|
||||||
|
self.assertEqual(tensor, torch.arange(numel, dtype=dtype, device=self.device))
|
||||||
|
|
||||||
|
@skipIfRocm
|
||||||
|
def test_mempool_compute_ops(self) -> None:
|
||||||
|
"""
|
||||||
|
Apply MemPool context to a compute op that creates input to collective.
|
||||||
|
"""
|
||||||
|
self._init_device()
|
||||||
|
group_name = dist.group.WORLD.group_name
|
||||||
|
symm_mem.enable_symm_mem_for_group(group_name)
|
||||||
|
|
||||||
|
dtype = torch.float
|
||||||
|
dim = 1024
|
||||||
|
w = torch.ones(dim, dim, dtype=dtype, device=self.device)
|
||||||
|
x0 = torch.ones(1, dim, dtype=dtype, device=self.device)
|
||||||
|
|
||||||
|
allocator = symm_mem.get_mempool_allocator(self.device)
|
||||||
|
mempool = torch.cuda.MemPool(allocator)
|
||||||
|
|
||||||
|
with torch.cuda.use_mem_pool(mempool):
|
||||||
|
x = x0 + self.rank
|
||||||
|
y = torch.mm(x, w)
|
||||||
|
|
||||||
|
# y should be a symm tensor
|
||||||
|
torch.ops.symm_mem.nvshmem_broadcast(y, group_name)
|
||||||
|
expected = torch.mm(x0, w)
|
||||||
|
self.assertEqual(y, expected)
|
||||||
|
|
||||||
@skipIfRocm
|
@skipIfRocm
|
||||||
def test_nvshmem_put(self) -> None:
|
def test_nvshmem_put(self) -> None:
|
||||||
self._init_device()
|
self._init_device()
|
||||||
|
@ -769,6 +769,8 @@ class _SymmetricMemory:
|
|||||||
def set_backend(name: str) -> None: ...
|
def set_backend(name: str) -> None: ...
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_backend(device: torch.device) -> Optional[str]: ...
|
def get_backend(device: torch.device) -> Optional[str]: ...
|
||||||
|
@staticmethod
|
||||||
|
def get_mempool_allocator(device: torch.device) -> Any: ...
|
||||||
@property
|
@property
|
||||||
def rank(self) -> int: ...
|
def rank(self) -> int: ...
|
||||||
@property
|
@property
|
||||||
|
@ -1128,6 +1128,9 @@ This class does not support ``__members__`` property.)");
|
|||||||
&::c10d::symmetric_memory::has_multicast_support)
|
&::c10d::symmetric_memory::has_multicast_support)
|
||||||
.def_static("set_backend", &::c10d::symmetric_memory::set_backend)
|
.def_static("set_backend", &::c10d::symmetric_memory::set_backend)
|
||||||
.def_static("get_backend", &::c10d::symmetric_memory::get_backend)
|
.def_static("get_backend", &::c10d::symmetric_memory::get_backend)
|
||||||
|
.def_static(
|
||||||
|
"get_mempool_allocator",
|
||||||
|
&::c10d::symmetric_memory::get_mempool_allocator)
|
||||||
.def_property_readonly("rank", &SymmetricMemory::get_rank)
|
.def_property_readonly("rank", &SymmetricMemory::get_rank)
|
||||||
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
|
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
|
@ -266,6 +266,61 @@ TORCH_API bool has_multicast_support(
|
|||||||
return allocator->has_multicast_support(device_idx);
|
return allocator->has_multicast_support(device_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MemPool Support
|
||||||
|
|
||||||
|
// A map from device type to allocator for MemPool.
|
||||||
|
// TODO: Consolidate with `AllocatorMap` above.
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
|
||||||
|
class MemPoolAllocatorMap {
|
||||||
|
public:
|
||||||
|
MemPoolAllocatorMap(const MemPoolAllocatorMap&) = delete;
|
||||||
|
MemPoolAllocatorMap& operator=(const MemPoolAllocatorMap&) = delete;
|
||||||
|
static MemPoolAllocatorMap& get() {
|
||||||
|
static MemPoolAllocatorMap instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register allocator for MemPool given device type
|
||||||
|
void register_mempool_allocator(
|
||||||
|
c10::DeviceType device_type,
|
||||||
|
std::shared_ptr<c10::Allocator> allocator) {
|
||||||
|
mempool_allocators_[device_type] = std::move(allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get allocator for MemPool given device
|
||||||
|
std::shared_ptr<c10::Allocator> get_mempool_allocator(c10::Device device) {
|
||||||
|
auto it = mempool_allocators_.find(device.type());
|
||||||
|
if (it == mempool_allocators_.end()) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"SymmetricMemory MemPool did not find backend for device type ",
|
||||||
|
device.type());
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
MemPoolAllocatorMap() = default;
|
||||||
|
|
||||||
|
std::unordered_map<c10::DeviceType, std::shared_ptr<c10::Allocator>>
|
||||||
|
mempool_allocators_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register allocator for MemPool given device type
|
||||||
|
C10_EXPORT void register_mempool_allocator(
|
||||||
|
c10::DeviceType device_type,
|
||||||
|
std::shared_ptr<c10::Allocator> allocator) {
|
||||||
|
return MemPoolAllocatorMap::get().register_mempool_allocator(
|
||||||
|
device_type, std::move(allocator));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get allocator for MemPool given device
|
||||||
|
TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
|
||||||
|
c10::Device device) {
|
||||||
|
return MemPoolAllocatorMap::get().get_mempool_allocator(device);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace c10d::symmetric_memory
|
} // namespace c10d::symmetric_memory
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -184,4 +184,11 @@ TORCH_API void set_backend(const std::string& name);
|
|||||||
|
|
||||||
TORCH_API std::optional<std::string> get_backend(c10::Device device);
|
TORCH_API std::optional<std::string> get_backend(c10::Device device);
|
||||||
|
|
||||||
|
C10_EXPORT void register_mempool_allocator(
|
||||||
|
c10::DeviceType device_type,
|
||||||
|
std::shared_ptr<c10::Allocator> allocator);
|
||||||
|
|
||||||
|
TORCH_API std::shared_ptr<c10::Allocator> get_mempool_allocator(
|
||||||
|
c10::Device device);
|
||||||
|
|
||||||
} // namespace c10d::symmetric_memory
|
} // namespace c10d::symmetric_memory
|
||||||
|
39
torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
Normal file
39
torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
|
||||||
|
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
using namespace c10d::symmetric_memory;
|
||||||
|
|
||||||
|
// Alloc functor for MemPool
|
||||||
|
void* cuda_symm_alloc(size_t size, int device, void* stream) {
|
||||||
|
static auto allocator = get_allocator(c10::DeviceType::CUDA);
|
||||||
|
TORCH_CHECK(
|
||||||
|
allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
|
||||||
|
// Note: this alloc functor works for the NVSHMEM and NCCL backends only,
|
||||||
|
// because only these backends takes `nullopt` for the `group` argument which
|
||||||
|
// is not given by MemPool's invocation (actually these two backends requires
|
||||||
|
// it to be `nullopt`).
|
||||||
|
return allocator->alloc(size, device, /*group_name=*/std::nullopt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free functor for MemPool
|
||||||
|
void cuda_symm_free(void* ptr, size_t size, int device, void* stream) {
|
||||||
|
static auto allocator = get_allocator(c10::DeviceType::CUDA);
|
||||||
|
TORCH_CHECK(
|
||||||
|
allocator->name() == "NVSHMEM", "Only NVSHMEM backend is supported");
|
||||||
|
allocator->free(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register allocator for CUDA MemPool
|
||||||
|
struct RegisterCUDAMemPoolAllocator {
|
||||||
|
RegisterCUDAMemPoolAllocator() {
|
||||||
|
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator =
|
||||||
|
torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
|
||||||
|
cuda_symm_alloc, cuda_symm_free);
|
||||||
|
register_mempool_allocator(c10::DeviceType::CUDA, allocator);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static RegisterCUDAMemPoolAllocator register_cuda_mempool_allocator_;
|
||||||
|
|
||||||
|
} // namespace
|
@ -1782,4 +1782,14 @@ def get_backend(device: _device) -> str | None:
|
|||||||
return _SymmetricMemory.get_backend(torch.device(device))
|
return _SymmetricMemory.get_backend(torch.device(device))
|
||||||
|
|
||||||
|
|
||||||
|
def get_mempool_allocator(device: _device): # type: ignore[no-untyped-def]
|
||||||
|
r"""
|
||||||
|
Get the MemPool allocator for symmetric memory for a given device.
|
||||||
|
Args:
|
||||||
|
device (class:`torch.device` or str): the device for which to get the
|
||||||
|
MemPool allocator.
|
||||||
|
"""
|
||||||
|
return _SymmetricMemory.get_mempool_allocator(torch.device(device))
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
|
__all__ = ["empty", "rendezvous", "is_nvshmem_available", "set_backend", "get_backend"]
|
||||||
|
Reference in New Issue
Block a user