From 77830d509fcae41be37f5b3a2fa05faabc778e29 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 18 Jun 2024 18:11:43 +0000 Subject: [PATCH] Revert "Introduce a prototype for SymmetricMemory (#128582)" This reverts commit 7a39755da28d5a109bf0c37f72b364d3a83137b1. Reverted https://github.com/pytorch/pytorch/pull/128582 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128582#issuecomment-2176685232)) --- .lintrunner.toml | 1 - BUILD.bazel | 1 - build_variables.bzl | 2 - c10/cuda/driver_api.h | 19 +- caffe2/CMakeLists.txt | 1 - test/distributed/test_symmetric_memory.py | 156 ----- torch/_C/_distributed_c10d.pyi | 30 - .../distributed/c10d/CUDASymmetricMemory.cu | 539 ------------------ .../distributed/c10d/CUDASymmetricMemory.cuh | 109 ---- .../distributed/c10d/ProcessGroupCudaP2P.hpp | 1 - .../csrc/distributed/c10d/SymmetricMemory.cpp | 189 ------ .../csrc/distributed/c10d/SymmetricMemory.hpp | 152 ----- torch/csrc/distributed/c10d/init.cpp | 39 -- .../csrc/distributed/c10d/intra_node_comm.cpp | 99 +++- .../csrc/distributed/c10d/intra_node_comm.cu | 18 +- .../csrc/distributed/c10d/intra_node_comm.hpp | 9 +- 16 files changed, 111 insertions(+), 1254 deletions(-) delete mode 100644 test/distributed/test_symmetric_memory.py delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cu delete mode 100644 torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.cpp delete mode 100644 torch/csrc/distributed/c10d/SymmetricMemory.hpp diff --git a/.lintrunner.toml b/.lintrunner.toml index dc9f9ddd46c7..a7bbdc884415 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -68,7 +68,6 @@ include_patterns = [ 'aten/src/ATen/native/cudnn/*.cpp', 'c10/**/*.h', 'c10/**/*.cpp', - 'distributed/c10d/*SymmetricMemory.*', 'torch/csrc/**/*.h', 'torch/csrc/**/*.hpp', 'torch/csrc/**/*.cpp', diff --git a/BUILD.bazel b/BUILD.bazel index c563c52d861e..10c065f5084c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,7 +744,6 @@ cc_library( "torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 793b611a0a6f..ceb28707897e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -501,7 +501,6 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/ProcessGroupMPI.cpp", "torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp", "torch/csrc/distributed/c10d/Store.cpp", - "torch/csrc/distributed/c10d/SymmetricMemory.cpp", "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", @@ -685,7 +684,6 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index cbbdf16823ec..43bcbd1d70ba 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -18,17 +18,14 @@ } \ } while (0) -#define 
C10_LIBCUDA_DRIVER_API(_) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ +#define C10_LIBCUDA_DRIVER_API(_) \ + _(cuMemAddressReserve) \ + _(cuMemRelease) \ + _(cuMemMap) \ + _(cuMemAddressFree) \ + _(cuMemSetAccess) \ + _(cuMemUnmap) \ + _(cuMemCreate) \ _(cuGetErrorString) #define C10_NVML_DRIVER_API(_) \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8426741609fe..89c31fab1134 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -560,7 +560,6 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp - ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py deleted file mode 100644 index a768e059044f..000000000000 --- a/test/distributed/test_symmetric_memory.py +++ /dev/null @@ -1,156 +0,0 @@ -# Owner(s): ["module: c10d"] - -import torch - -import torch.distributed as dist -from torch._C._distributed_c10d import _SymmetricMemory -from torch.distributed.distributed_c10d import _get_process_group_store - -from torch.testing._internal.common_distributed import ( - MultiProcessTestCase, - skip_if_lt_x_gpu, -) -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - run_tests, - skip_but_pass_in_sandcastle_if, - skipIfRocm, -) - - -def requires_cuda_p2p_access(): - cuda_p2p_access_available = ( - torch.cuda.is_available() and torch.cuda.device_count() >= 2 - ) - num_devices = torch.cuda.device_count() - for i in range(num_devices - 1): - for j in range(i + 1, num_devices): - if not torch.cuda.can_device_access_peer(i, j): - cuda_p2p_access_available = False - break - if not cuda_p2p_access_available: - break - - return skip_but_pass_in_sandcastle_if( - not cuda_p2p_access_available, - "cuda p2p access is not available", - ) - - -@instantiate_parametrized_tests -@requires_cuda_p2p_access() -class SymmetricMemoryTest(MultiProcessTestCase): - def setUp(self) -> None: - super().setUp() - self._spawn_processes() - - @property - def world_size(self) -> int: - return 2 - - @property - def device(self) -> torch.device: - return torch.device(f"cuda:{self.rank}") - - def _init_process(self): - torch.cuda.set_device(self.device) - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="nccl", - world_size=self.world_size, - rank=self.rank, - store=store, - ) - _SymmetricMemory.set_group_info( - "0", - self.rank, - self.world_size, - _get_process_group_store(dist.GroupMember.WORLD), - ) - - def _verify_symmetric_memory(self, symm_mem): - self.assertEqual(symm_mem.world_size, 2) - - buf = symm_mem.get_buffer(0, (64, 64), torch.float32) - if symm_mem.rank == 0: - symm_mem.wait_signal(src_rank=1) - self.assertTrue(buf.eq(42).all()) - else: - buf.fill_(42) - symm_mem.put_signal(dst_rank=0) - - symm_mem.barrier() - - if symm_mem.rank == 0: - symm_mem.barrier() - self.assertTrue(buf.eq(43).all()) - else: - buf.fill_(43) - symm_mem.barrier() - - symm_mem.barrier() - - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_empty_strided_p2p(self) -> None: - self._init_process() - - shape = (64, 64) - stride = (64, 1) - 
dtype = torch.float32
-        device = self.device
-        group_name = "0"
-        alloc_args = (shape, stride, dtype, device, group_name)
-
-        t = torch.empty(shape, dtype=dtype, device=device)
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.rendezvous(t)
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
-
-        del t
-        self._verify_symmetric_memory(symm_mem)
-
-    @skipIfRocm
-    @skip_if_lt_x_gpu(2)
-    def test_empty_strided_p2p_persistent(self) -> None:
-        self._init_process()
-
-        shape = (64, 64)
-        stride = (64, 1)
-        dtype = torch.float32
-        device = self.device
-        alloc_id = 42  # Persistent allocation
-        group_name = "0"
-        alloc_args = (shape, stride, dtype, device, group_name, alloc_id)
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        data_ptr = t.data_ptr()
-
-        # Verify that persistent allocation would fail if there's an active
-        # allocation with the same alloc_id.
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.empty_strided_p2p(*alloc_args)
-
-        # Verify that persistent allocation would succeed when there is no
-        # active allocation with the same alloc_id, and that the returned
-        # tensor has the same data pointer.
-        del t
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        self.assertEqual(t.data_ptr(), data_ptr)
-
-        # Verify that get_symmetric_memory would fail if called before
-        # rendezvous.
-        with self.assertRaises(RuntimeError):
-            _SymmetricMemory.get_symmetric_memory(t)
-
-        symm_mem_0 = _SymmetricMemory.rendezvous(t)
-        symm_mem_1 = _SymmetricMemory.get_symmetric_memory(t)
-        self.assertEqual(id(symm_mem_0), id(symm_mem_1))
-
-        self._verify_symmetric_memory(symm_mem_0)
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 0095b5af434b..cffbf22219c8 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -637,33 +637,3 @@ class ProcessGroupCudaP2P(Backend):
         storage_offset: Optional[int] = 0,
     ) -> torch.Tensor: ...
     def _shutdown(self) -> None: ...
-
-class _SymmetricMemory:
-    @staticmethod
-    def set_group_info(
-        group_name: str, rank: int, world_size: int, store: Store
-    ) -> None: ...
-    @staticmethod
-    def empty_strided_p2p(
-        size: torch.types._size,
-        stride: torch.types._size,
-        dtype: torch.dtype,
-        device: torch.device,
-        group_name: str,
-    ) -> torch.Tensor: ...
-    @property
-    def rank(self) -> int: ...
-    @property
-    def world_size(self) -> int: ...
-    @staticmethod
-    def rendezvous(tensor: torch.Tensor) -> _SymmetricMemory: ...
-    def get_buffer(
-        self,
-        rank: int,
-        sizes: torch.Size,
-        dtype: torch.dtype,
-        storage_offset: Optional[int] = 0,
-    ) -> torch.Tensor: ...
-    def barrier(self, channel: int = 0) -> None: ...
-    def put_signal(self, dst_rank: int, channel: int = 0) -> None: ...
-    def wait_signal(self, src_rank: int, channel: int = 0) -> None: ...
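The deleted test and type stub above pin down the user-visible contract, and the same flow is reachable from C++ through the functions declared in the (also deleted) SymmetricMemory.hpp further below. A minimal host-side sketch of that flow — not code from this PR; the group name "0", the shapes, and the ring-neighbor choice are illustrative, and it assumes the reverted headers are present:

#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

using namespace c10d::symmetric_memory;

at::Tensor exchange_with_next_rank(
    c10::intrusive_ptr<c10d::Store> store, int rank, int world_size) {
  // One-time registration of the bootstrap store for group "0".
  set_group_info("0", rank, world_size, std::move(store));

  // Symmetric allocation; this call by itself is not collective.
  at::Tensor t = empty_strided_p2p(
      {64, 64},
      {64, 1},
      at::kFloat,
      c10::Device(c10::DeviceType::CUDA, static_cast<c10::DeviceIndex>(rank)),
      "0",
      /*alloc_id=*/std::nullopt);

  // Collective: every rank must call rendezvous() on its own allocation.
  auto symm_mem = rendezvous(t);

  // Synchronize, then view the next rank's buffer as a regular tensor.
  symm_mem->barrier(/*channel=*/0);
  return symm_mem->get_buffer(
      (rank + 1) % world_size, {64, 64}, at::kFloat, /*storage_offset=*/0);
}

Note that rendezvous() is the only collective step here; the allocation and the get_buffer() call are purely local.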
diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu deleted file mode 100644 index f27db85f7ff8..000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ /dev/null @@ -1,539 +0,0 @@ -#include - -#include -#include -#include -#include - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include -#endif - -#include -#include - -namespace { - -constexpr size_t signal_pad_size = 2048; -const std::string store_comm_prefix = "CUDASymmetricMemory"; - -static size_t store_comm_seq_id = 0; - -template -std::vector store_all_gather( - const c10::intrusive_ptr& store, - int rank, - int world_size, - T val) { - static_assert(std::is_trivially_copyable_v); - - std::vector peer_keys; - for (int r = 0; r < world_size; ++r) { - std::ostringstream oss; - oss << store_comm_prefix << "/" << store_comm_seq_id << "/" << r; - peer_keys.push_back(oss.str()); - } - ++store_comm_seq_id; - - { - std::vector payload( - reinterpret_cast(&val), - reinterpret_cast(&val) + sizeof(T)); - store->set(peer_keys[rank], payload); - } - - std::vector peer_vals; - for (int r = 0; r < world_size; ++r) { - if (r == rank) { - peer_vals.push_back(val); - continue; - } - store->wait({peer_keys[r]}); - auto payload = store->get(peer_keys[r]); - TORCH_CHECK(payload.size() == sizeof(T)); - T peer_val{}; - std::memcpy(&peer_val, payload.data(), sizeof(T)); - peer_vals.push_back(peer_val); - } - return peer_vals; -} - -void store_barrier( - const c10::intrusive_ptr& store, - int rank, - int world_size) { - store_all_gather(store, rank, world_size, 0); -} - -int import_remote_fd(int pid, int fd) { -#if defined(SYS_pidfd_open) and defined(SYS_pidfd_getfd) - int pidfd = syscall(SYS_pidfd_open, pid, 0); - return syscall(SYS_pidfd_getfd, pidfd, fd, 0); -#else - TORCH_CHECK( - false, - "CUDASymmetricMemory requires pidfd_open ", - "and pidfd_getfd support"); -#endif -} - -void map_block( - void** ptr, - c10d::symmetric_memory::HandleType handle, - size_t size, - int device_idx) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - auto dev_ptr = reinterpret_cast(ptr); - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemAddressReserve_(dev_ptr, size, 0ULL, 0, 0ULL)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemMap_(*dev_ptr, size, 0, handle, 0ULL)); - - CUmemAccessDesc desc; - desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - desc.location.id = static_cast(device_idx); - desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1)); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -CUDASymmetricMemory::CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size) - : handles_(std::move(handles)), - block_size_(block_size), - buffers_(std::move(buffers)), - signal_pads_(std::move(signal_pads)), - buffer_size_(buffer_size), - local_device_idx_(local_device_idx), - rank_(rank), - world_size_(world_size) { - const size_t arr_size = sizeof(void*) * world_size_; - buffers_dev_ = reinterpret_cast( - c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - signal_pads_dev_ = reinterpret_cast( - 
c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); - - c10::cuda::CUDAGuard guard(local_device_idx); - AT_CUDA_CHECK(cudaMemcpy( - buffers_dev_, buffers_.data(), arr_size, cudaMemcpyHostToDevice)); - AT_CUDA_CHECK(cudaMemcpy( - signal_pads_dev_, signal_pads_.data(), arr_size, cudaMemcpyHostToDevice)); -} - -CUDASymmetricMemory::~CUDASymmetricMemory() { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - c10::cuda::CUDAGuard guard(local_device_idx_); - C10_CUDA_CHECK(cudaDeviceSynchronize()); - - auto driver_api = c10::cuda::DriverAPI::get(); - for (int r = 0; r < world_size_; ++r) { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(buffers_[r]), block_size_)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handles_[r])); - } - c10::cuda::CUDACachingAllocator::raw_delete(buffers_dev_); - c10::cuda::CUDACachingAllocator::raw_delete(signal_pads_dev_); -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -std::vector CUDASymmetricMemory::get_buffer_ptrs() { - return buffers_; -} - -std::vector CUDASymmetricMemory::get_signal_pad_ptrs() { - return signal_pads_; -} - -void** CUDASymmetricMemory::get_buffer_ptrs_dev() { - return buffers_dev_; -} - -void** CUDASymmetricMemory::get_signal_pad_ptrs_dev() { - return signal_pads_dev_; -} - -size_t CUDASymmetricMemory::get_buffer_size() { - return buffer_size_; -} - -size_t CUDASymmetricMemory::get_signal_pad_size() { - return signal_pad_size; -} - -at::Tensor CUDASymmetricMemory::get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) { - const auto numel = - std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); - const auto element_size = c10::elementSize(dtype); - const auto req_size = (numel + storage_offset) * element_size; - TORCH_CHECK( - req_size <= buffer_size_, - "CUDASymmetricMemory::get_buffer: the requested size (", - req_size, - " bytes) exceeds the allocated size (", - buffer_size_, - " bytes)"); - auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storage_offset) - .options(options) - .target_device(device) - .make_tensor(); -} - -void check_channel(int channel, int world_size) { - TORCH_CHECK( - channel >= 0, - "channel for barrier(), put_signal() and wait_signal() ", - "must be greater than 0 (got ", - channel, - ")"); - const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; - TORCH_CHECK( - static_cast(channel) < num_channels, - "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", - num_channels - 1, - " (got ", - channel, - ")"); -} - -__device__ __forceinline__ void release_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); -#endif -} - -__device__ __forceinline__ void acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); -#endif -} - -static __global__ void barrier_kernel( - uint32_t** signal_pads, - int channel, - int rank, - int 
world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + world_size * channel + rank); - acquire_signal(signal_pads[rank] + world_size * channel + target_rank); - } -} - -void CUDASymmetricMemory::barrier(int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void put_signal_kernel( - uint32_t** signal_pads, - int dst_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - release_signal(signal_pads[dst_rank] + world_size * channel + rank); - } -} - -void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - dst_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -static __global__ void wait_signal_kernel( - uint32_t** signal_pads, - int src_rank, - int channel, - int rank, - int world_size) { - if (threadIdx.x == 0) { - acquire_signal(signal_pads[rank] + world_size * channel + src_rank); - } - __threadfence_system(); -} - -void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { - check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); - wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(signal_pads_dev_), - src_rank, - channel, - rank_, - world_size_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -int CUDASymmetricMemory::get_rank() { - return rank_; -} - -int CUDASymmetricMemory::get_world_size() { - return world_size_; -} - -void* CUDASymmetricMemoryAllocator::alloc( - size_t size, - int device_idx, - const std::string& group_name) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto driver_api = c10::cuda::DriverAPI::get(); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - // NOLINTNEXTLINE(bugprone-signed-char-misuse) - prop.location.id = device_idx; - prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; - - size_t signal_pad_offset = at::round_up(size, 16UL); - size_t block_size = signal_pad_offset + signal_pad_size; - - size_t granularity; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemGetAllocationGranularity_( - &granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); - block_size = at::round_up(block_size, granularity); - - HandleType handle; - C10_CUDA_DRIVER_CHECK( - driver_api->cuMemCreate_(&handle, block_size, &prop, 0)); - - void* ptr = nullptr; - map_block(&ptr, handle, block_size, device_idx); - - c10::cuda::CUDAGuard guard(device_idx); - AT_CUDA_CHECK(cudaMemset(ptr, 0, block_size)); - - auto block = c10::make_intrusive( - handle, device_idx, block_size, size, signal_pad_offset, group_name); - { - std::unique_lock lock(mutex_); - ptr_to_block_.emplace(ptr, std::move(block)); - } - return ptr; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -void CUDASymmetricMemoryAllocator::free(void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - if (block == nullptr) { - 
return; - } - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. - if (block->symm_mem == nullptr) { - auto driver_api = c10::cuda::DriverAPI::get(); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemUnmap_( - reinterpret_cast(ptr), block->block_size)); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(block->handle)); - } - { - std::unique_lock lock(mutex_); - ptr_to_block_.erase(ptr); - } -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -size_t CUDASymmetricMemoryAllocator::get_alloc_size(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::get_alloc_size: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->buffer_size; -} - -struct RendezvousRequest { - int device_idx; - int block_fd; - int pid; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; -}; - -void validate_rendezvous_requests( - const std::vector reqs, - int world_size) { - TORCH_CHECK(reqs.size() == (size_t)world_size); - - std::unordered_set device_indices; - device_indices.reserve(world_size); - for (auto req : reqs) { - device_indices.insert(req.device_idx); - } - if (device_indices.size() < (size_t)world_size) { - TORCH_CHECK( - false, - "CUDASymmetricMemoryAllocator::rendezvous: ", - "detected allocations from overlapping devices ", - "from different ranks."); - } - - for (int r = 1; r < world_size; ++r) { - TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); - TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); - TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); - } -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( - void* ptr) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::rendezvous: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - - if (block->symm_mem != nullptr) { - return block->symm_mem; - } - - auto group_info = get_group_info(block->group_name); - auto driver_api = c10::cuda::DriverAPI::get(); - int block_fd; - C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &block_fd, block->handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); - - auto local_req = RendezvousRequest{ - .device_idx = block->device_idx, - .block_fd = block_fd, - .pid = getpid(), - .block_size = block->block_size, - .buffer_size = block->buffer_size, - .signal_pad_offset = block->signal_pad_offset}; - auto reqs = store_all_gather( - group_info.store, group_info.rank, group_info.world_size, local_req); - validate_rendezvous_requests(reqs, group_info.world_size); - - std::vector handles(group_info.world_size); - std::vector buffers(group_info.world_size, nullptr); - std::vector signal_pads(group_info.world_size, nullptr); - for (int r = 0; r < group_info.world_size; ++r) { - if (r == group_info.rank) { - handles[r] = block->handle; - buffers[r] = ptr; - signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); - continue; - } - int imported_fd = import_remote_fd(reqs[r].pid, reqs[r].block_fd); - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)(uintptr_t)imported_fd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - map_block(&buffers[r], handles[r], block->block_size, block->device_idx); - signal_pads[r] = (void*)((uintptr_t)buffers[r] + 
block->signal_pad_offset); - close(imported_fd); - } - store_barrier(group_info.store, group_info.rank, group_info.world_size); - close(block_fd); - - // Initializing CUDASymmetricMemory with an allocation transfers its - // ownership to the CUDASymmetricMemory object. So that outstanding - // references to the CUDASymmetricMemory object can keep the allocation - // alive. - block->symm_mem = c10::make_intrusive( - std::move(handles), - block->block_size, - std::move(buffers), - std::move(signal_pads), - block->buffer_size, - block->device_idx, - group_info.rank, - group_info.world_size); - return block->symm_mem; -#else - TORCH_CHECK( - false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED"); -#endif -} - -bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { - auto block = find_block(ptr); - TORCH_CHECK( - block != nullptr, - "CUDASymmetricMemoryAllocator::is_rendezvous_completed: input must be allocated ", - "via CUDASymmetricMemoryAllocator::alloc"); - return block->symm_mem != nullptr; -} - -c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { - std::shared_lock lock(mutex_); - auto it = ptr_to_block_.find(ptr); - if (it == ptr_to_block_.end()) { - return nullptr; - } - return it->second; -} - -struct RegisterCUDASymmetricMemoryAllocator { - RegisterCUDASymmetricMemoryAllocator() { - register_allocator( - c10::DeviceType::CUDA, - c10::make_intrusive()); - } -}; - -static RegisterCUDASymmetricMemoryAllocator register_allocator_; - -} // namespace symmetric_memory -} // namespace c10d diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh deleted file mode 100644 index 0e0e40a6bd09..000000000000 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cuh +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace c10d { -namespace symmetric_memory { - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -using HandleType = CUmemGenericAllocationHandle; -#else -using HandleType = void*; -#endif - -class CUDASymmetricMemory : public SymmetricMemory { - public: - CUDASymmetricMemory( - std::vector handles, - size_t block_size, - std::vector buffers, - std::vector signal_pads, - size_t buffer_size, - int local_device_idx, - int rank, - int world_size); - - ~CUDASymmetricMemory() override; - - std::vector get_buffer_ptrs() override; - std::vector get_signal_pad_ptrs() override; - void** get_buffer_ptrs_dev() override; - void** get_signal_pad_ptrs_dev() override; - size_t get_buffer_size() override; - size_t get_signal_pad_size() override; - - at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) override; - - void barrier(int channel) override; - void put_signal(int dst_rank, int channel) override; - void wait_signal(int src_rank, int channel) override; - - int get_rank() override; - int get_world_size() override; - - private: - std::vector handles_; - size_t block_size_; - std::vector buffers_; - std::vector signal_pads_; - size_t buffer_size_; - int local_device_idx_; - int rank_; - int world_size_; - void** buffers_dev_; - void** signal_pads_dev_; - std::optional> finalizer_; -}; - -struct Block : public c10::intrusive_ptr_target { - HandleType handle; - int device_idx; - size_t block_size; - size_t buffer_size; - size_t signal_pad_offset; - std::string group_name; - c10::intrusive_ptr symm_mem = nullptr; - - Block( - HandleType handle, - int device_idx, - size_t 
block_size,
-      size_t buffer_size,
-      size_t signal_pad_offset,
-      const std::string& group_name)
-      : handle(handle),
-        device_idx(device_idx),
-        block_size(block_size),
-        buffer_size(buffer_size),
-        signal_pad_offset(signal_pad_offset),
-        group_name(group_name),
-        symm_mem(nullptr) {}
-};
-
-class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
- public:
-  void* alloc(
-      size_t size,
-      int device_idx,
-      const std::string& group_name) override;
-
-  void free(void *ptr) override;
-  size_t get_alloc_size(void* ptr) override;
-  c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
-  bool is_rendezvous_completed(void* ptr) override;
-
- private:
-  c10::intrusive_ptr<Block> find_block(void* ptr);
-
-  std::shared_mutex mutex_;
-  std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
-};
-
-} // namespace symmetric_memory
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
index 7c41414c4e4e..cff4ad09b706 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp
@@ -10,7 +10,6 @@ constexpr auto kProcessGroupCudaP2PDefaultTimeout =
 
 namespace c10d {
 
-// NOTE: this class will be removed soon in favor of SymmetricMemory
 class TORCH_API ProcessGroupCudaP2P : public Backend {
  public:
   struct Options : Backend::Options {
diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp
deleted file mode 100644
index b3d9f31bb034..000000000000
--- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp
+++ /dev/null
@@ -1,189 +0,0 @@
-#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
-
-namespace {
-
-using namespace c10d::symmetric_memory;
-
-class AllocatorMap {
- public:
-  static AllocatorMap& get() {
-    static AllocatorMap instance;
-    return instance;
-  }
-
-  void register_allocator(
-      c10::DeviceType device_type,
-      c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
-    map_[device_type] = std::move(allocator);
-  }
-
-  c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
-      c10::DeviceType device_type) {
-    auto it = map_.find(device_type);
-    TORCH_CHECK(
-        it != map_.end(),
-        "SymmetricMemory does not support device type ",
-        device_type);
-    return it->second;
-  }
-
-  ~AllocatorMap() {
-    for (auto& it : map_) {
-      it.second.release();
-    }
-  }
-
- private:
-  AllocatorMap() = default;
-  AllocatorMap(const AllocatorMap&) = delete;
-  AllocatorMap& operator=(const AllocatorMap&) = delete;
-
-  std::unordered_map<
-      c10::DeviceType,
-      c10::intrusive_ptr<SymmetricMemoryAllocator>>
-      map_;
-};
-
-static std::unordered_map<std::string, GroupInfo> group_info_map{};
-
-// Data structures for tracking persistent allocations
-static std::unordered_map<uint64_t, void*> alloc_id_to_dev_ptr{};
-static std::unordered_map<uint64_t, c10::weak_intrusive_ptr<c10::StorageImpl>>
-    alloc_id_to_storage{};
-
-static at::Tensor empty_strided_p2p_persistent(
-    c10::IntArrayRef size,
-    c10::IntArrayRef stride,
-    c10::ScalarType dtype,
-    c10::Device device,
-    const std::string& group_name,
-    uint64_t alloc_id) {
-  // Make the allocation fail if a previous allocation with the same alloc_id
-  // is still active.
- auto storage = alloc_id_to_storage.find(alloc_id); - if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) { - TORCH_CHECK( - false, - "SymmetricMemory::empty_strided_p2p_persistent: ", - "can not allocate with alloc_id == ", - alloc_id, - " because a previous allocation with the same alloc_id " - "is still active."); - } - - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = nullptr; - if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) { - dev_ptr = alloc_id_to_dev_ptr[alloc_id]; - TORCH_CHECK( - alloc_size == allocator->get_alloc_size(dev_ptr), - "SymmetricMemory::empty_strided_p2p_persistent: ", - "requested allocation size (", - alloc_size, - ") is different from the size of a previous allocation ", - "with the same alloc_id ", - allocator->get_alloc_size(dev_ptr)); - } else { - dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - alloc_id_to_dev_ptr[alloc_id] = dev_ptr; - } - - auto options = at::TensorOptions().dtype(dtype).device(device); - auto allocated = at::from_blob(dev_ptr, size, stride, options); - - // Track the allocation's activeness - alloc_id_to_storage.erase(alloc_id); - alloc_id_to_storage.emplace( - alloc_id, allocated.storage().getWeakStorageImpl()); - return allocated; -} - -} // namespace - -namespace c10d { -namespace symmetric_memory { - -void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator) { - return AllocatorMap::get().register_allocator( - device_type, std::move(allocator)); -} - -c10::intrusive_ptr get_allocator( - c10::DeviceType device_type) { - return AllocatorMap::get().get_allocator(device_type); -} - -void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store) { - TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end()); - GroupInfo group_info; - group_info.rank = rank; - group_info.world_size = world_size; - group_info.store = std::move(store); - group_info_map.emplace(group_name, std::move(group_info)); -} - -const GroupInfo& get_group_info(const std::string& group_name) { - TORCH_CHECK( - group_info_map.find(group_name) != group_info_map.end(), - "get_group_info: no group info associated with the group name ", - group_name); - return group_info_map[group_name]; -} - -at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id) { - if (alloc_id.has_value()) { - return empty_strided_p2p_persistent( - size, stride, dtype, device, group_name, *alloc_id); - } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); - const size_t element_size = c10::elementSize(dtype); - const size_t alloc_size = numel * element_size; - - auto allocator = get_allocator(device.type()); - void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name); - - auto options = at::TensorOptions().dtype(dtype).device(device); - return at::from_blob( - dev_ptr, - size, - stride, - [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); }, - options); -} - -TORCH_API c10::intrusive_ptr rendezvous( - const at::Tensor& tensor) { - auto allocator = get_allocator(tensor.device().type()); - return allocator->rendezvous(tensor.data_ptr()); -} - 
-c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
-    const at::Tensor& tensor) {
-  auto allocator = get_allocator(tensor.device().type());
-  TORCH_CHECK(
-      allocator->is_rendezvous_completed(tensor.data_ptr()),
-      "SymmetricMemory: must invoke rendezvous on a tensor ",
-      "before calling get_symmetric_memory on it");
-  return allocator->rendezvous(tensor.data_ptr());
-}
-
-} // namespace symmetric_memory
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp
deleted file mode 100644
index 344b86ea5c7e..000000000000
--- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp
+++ /dev/null
@@ -1,152 +0,0 @@
-#pragma once
-
-#include <ATen/ATen.h>
-#include <torch/csrc/distributed/c10d/Store.hpp>
-
-namespace c10d {
-namespace symmetric_memory {
-
-// SymmetricMemory represents symmetric allocations across a group of devices.
-// The allocations represented by a SymmetricMemory object are accessible by
-// all devices in the group. The class can be used for op-level custom
-// communication patterns (via the get_buffer APIs and the synchronization
-// primitives), as well as custom communication kernels (via the buffer and
-// signal_pad device pointers).
-//
-// To acquire a SymmetricMemory object, each rank first allocates
-// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes
-// SymmetricMemoryAllocator::rendezvous() on the memory to establish the
-// association across peer buffers. The rendezvous is a one-time process, and
-// the mapping between a local memory region and the associated
-// SymmetricMemory object is unique.
-//
-// NOTE [symmetric memory signal pad]
-// Signal pads are P2P-accessible memory regions designated for
-// synchronization. SymmetricMemory offers built-in synchronization primitives
-// such as barriers, put_signal, and wait_signal, which are all based on signal
-// pads. Users may utilize signal pads for their own synchronization logic,
-// provided that the signal pads remain zero-filled following successful
-// synchronization.
-//
-// NOTE [symmetric memory synchronization channel]
-// Synchronization channels allow users to use a single SymmetricMemory object
-// to perform isolated synchronizations on different streams. For example,
-// consider the case in which two barriers are issued on two streams for
-// different purposes. Without the concept of channels, we cannot guarantee
-// the correctness of the barriers since signals issued from the barrier on
-// stream A can be received by the barrier on stream B. By specifying
-// different channels for these two barriers, they can operate correctly in
-// parallel.
-class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
- public:
-  virtual ~SymmetricMemory() {}
-
-  virtual std::vector<void*> get_buffer_ptrs() = 0;
-  virtual std::vector<void*> get_signal_pad_ptrs() = 0;
-
-  // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer
-  // to a device array of size world_size, containing buffer pointers and
-  // signal pad pointers, respectively.
- virtual void** get_buffer_ptrs_dev() = 0; - virtual void** get_signal_pad_ptrs_dev() = 0; - virtual size_t get_buffer_size() = 0; - virtual size_t get_signal_pad_size() = 0; - - virtual at::Tensor get_buffer( - int rank, - c10::IntArrayRef sizes, - c10::ScalarType dtype, - int64_t storage_offset) = 0; - - virtual void barrier(int channel) = 0; - virtual void put_signal(int dst_rank, int channel) = 0; - virtual void wait_signal(int src_rank, int channel) = 0; - - virtual int get_rank() = 0; - virtual int get_world_size() = 0; -}; - -class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { - public: - virtual ~SymmetricMemoryAllocator(){}; - - virtual void* alloc( - size_t size, - int device_idx, - const std::string& group_name) = 0; - - virtual void free(void* ptr) = 0; - virtual size_t get_alloc_size(void* ptr) = 0; - virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; - virtual bool is_rendezvous_completed(void* ptr) = 0; -}; - -C10_EXPORT void register_allocator( - c10::DeviceType device_type, - c10::intrusive_ptr allocator); - -C10_EXPORT c10::intrusive_ptr get_allocator( - c10::DeviceType device_type); - -// Set a store for rendezvousing symmetric allocations on a group of devices -// identified by `group_name`. The concept of groups is logical; users can -// utilize predefined groups (e.g., a group of device identified by a -// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator -// backends might employ a more efficient communication channel for the actual -// rendezvous process and only use the store for bootstrapping purposes. -TORCH_API void set_group_info( - const std::string& group_name, - int rank, - int world_size, - c10::intrusive_ptr store); - -struct GroupInfo { - int rank; - int world_size; - c10::intrusive_ptr store; -}; - -C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); - -// Identical to empty_strided, but allows symmetric memory access to be -// established for the allocated tensor via SymmetricMemory::rendezvous(). This -// function itself is not a collective operation. It invokes -// SymmetricMemoryAllocator::alloc() for the requested device under the hood. -// -// NOTE [symmetric memory persistent allocation] -// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent -// allocation. This makes the function cache allocated memory and ensure that -// invocations with the same `alloc_id` receive tensors backed by the same -// memory address. For safety, if a previous persistent allocation is still -// active (i.e., the storage of the returned tensor is still alive), persistent -// allocations with the same `alloc_id` will fail. This determinism coupled -// with memory planning of communication buffers (e.g., by Inductor) allows -// communication algorithms to reliably reuse previously established remote -// memory access. -TORCH_API at::Tensor empty_strided_p2p( - c10::IntArrayRef size, - c10::IntArrayRef stride, - c10::ScalarType dtype, - c10::Device device, - const std::string& group_name, - std::optional alloc_id); - -// Establishes symmetric memory access on tensors allocated via -// empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a -// one-time process, and the mapping between a local memory region and the -// associated SymmetricMemory object is unique. Subsequent calls to -// rendezvous() with the same tensor, or tensors allocated with -// empty_strided_p2p_persistent() using the same alloc_id, will receive the -// cached SymmetricMemory object. 
-
-// The function has a collective semantic and must be invoked simultaneously
-// from all rendezvous participants.
-TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
-    const at::Tensor& tensor);
-
-// Returns the SymmetricMemory object associated with the tensor. It can only
-// be invoked after rendezvous() but does not need to be invoked collectively.
-TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
-    const at::Tensor& tensor);
-
-} // namespace symmetric_memory
-} // namespace c10d
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index db5778efcf35..6f1b28886b98 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -41,7 +41,6 @@
 #include
 #include
 #include
-#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
 #include
 #include
@@ -976,44 +975,6 @@ This class does not support ``__members__`` property.)");
         "global_ranks_in_group",
         &::c10d::DistributedBackendOptions::global_ranks_in_group);
 
-  using SymmetricMemory = ::c10d::symmetric_memory::SymmetricMemory;
-  py::class_<SymmetricMemory, c10::intrusive_ptr<SymmetricMemory>>(
-      module, "_SymmetricMemory")
-      .def_static("set_group_info", &::c10d::symmetric_memory::set_group_info)
-      .def_static(
-          "empty_strided_p2p",
-          ::c10d::symmetric_memory::empty_strided_p2p,
-          py::arg("size"),
-          py::arg("stride"),
-          py::arg("dtype"),
-          py::arg("device"),
-          py::arg("group_name"),
-          py::arg("alloc_id") = py::none())
-      .def_static("rendezvous", &::c10d::symmetric_memory::rendezvous)
-      .def_static(
-          "get_symmetric_memory",
-          &::c10d::symmetric_memory::get_symmetric_memory)
-      .def_property_readonly("rank", &SymmetricMemory::get_rank)
-      .def_property_readonly("world_size", &SymmetricMemory::get_world_size)
-      .def(
-          "get_buffer",
-          &SymmetricMemory::get_buffer,
-          py::arg("rank"),
-          py::arg("sizes"),
-          py::arg("dtype"),
-          py::arg("storage_offset") = 0)
-      .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0)
-      .def(
-          "put_signal",
-          &SymmetricMemory::put_signal,
-          py::arg("dst_rank"),
-          py::arg("channel") = 0)
-      .def(
-          "wait_signal",
-          &SymmetricMemory::wait_signal,
-          py::arg("src_rank"),
-          py::arg("channel") = 0);
-
   auto store =
       py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
           module,
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp
index 9d7ba5abf951..85136a91e025 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp
@@ -218,8 +218,23 @@ IntraNodeComm::~IntraNodeComm() {
   if (!isInitialized_) {
     return;
   }
-  auto allocator = get_allocator(c10::DeviceType::CUDA);
-  allocator->free(symmetricMemoryPtr_);
+  // Intentionally releasing resources without synchronizing devices. The
+  // teardown logic is safe for properly sync'd user programs. We don't want
+  // improperly sync'd user programs to hang here.
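+  // Below: close the CUDA IPC mappings opened for each peer
+  // (cudaIpcCloseMemHandle), then free this rank's own allocations and the
+  // device-side pointer arrays (cudaFree).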
+ for (size_t r = 0; r < worldSize_; ++r) { + if (r == rank_) { + continue; + } + AT_CUDA_CHECK(cudaIpcCloseMemHandle(p2pStates_[r])); + AT_CUDA_CHECK(cudaIpcCloseMemHandle(buffers_[r])); + } + AT_CUDA_CHECK(cudaFree(p2pStates_[rank_])); + AT_CUDA_CHECK(cudaFree(buffers_[rank_])); + if (topoInfo_ != nullptr) { + AT_CUDA_CHECK(cudaFree(topoInfo_)); + } + AT_CUDA_CHECK(cudaFree(p2pStatesDev_)); + AT_CUDA_CHECK(cudaFree(buffersDev_)); } bool IntraNodeComm::isEnabled() { @@ -329,19 +344,83 @@ bool IntraNodeComm::rendezvous() { // Detect topology Topology topology = detectTopology(nvlMesh, worldSize_); - set_group_info("IntraNodeComm", rank_, worldSize_, store_); - auto allocator = get_allocator(c10::DeviceType::CUDA); - symmetricMemoryPtr_ = - allocator->alloc(bufferSize_, deviceIdx, "IntraNodeComm"); - symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); - TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); + // Initialize p2p state + auto p2pState = initP2pState(); + + // Allocate buffer + void* buffer = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffer, bufferSize_)); + + // Second handshake: exchange topology and CUDA IPC handles + struct IpcInfo { + NvlMesh nvlMesh; + Topology topology; + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + }; + + // Make p2p state and buffer available for IPC + cudaIpcMemHandle_t p2pStateHandle, bufferHandle; + AT_CUDA_CHECK(cudaIpcGetMemHandle(&p2pStateHandle, p2pState)); + AT_CUDA_CHECK(cudaIpcGetMemHandle(&bufferHandle, buffer)); + + IpcInfo ipcInfo{ + .nvlMesh = nvlMesh, + .topology = topology, + .p2pStateHandle = p2pStateHandle, + .bufferHandle = bufferHandle}; + + auto peerIpcInfos = + storeAllGather(store_, "handshake-1", rank_, worldSize_, ipcInfo); + + for (const auto& info : peerIpcInfos) { + if (!isSame(info.nvlMesh, peerIpcInfos.front().nvlMesh) || + info.topology != peerIpcInfos.front().topology) { + LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " + "participants are observing different topologies (" + << int(info.topology) << " and " << int(topology) << ")"; + AT_CUDA_CHECK(cudaFree(p2pState)); + AT_CUDA_CHECK(cudaFree(buffer)); + return false; + } + } + + std::array p2pStates = {}, buffers = {}; + for (size_t r = 0; r < peerIpcInfos.size(); ++r) { + if (r == rank_) { + p2pStates[r] = p2pState; + buffers[r] = buffer; + } else { + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &p2pStates[r], + peerIpcInfos[r].p2pStateHandle, + cudaIpcMemLazyEnablePeerAccess)); + AT_CUDA_CHECK(cudaIpcOpenMemHandle( + &buffers[r], + peerIpcInfos[r].bufferHandle, + cudaIpcMemLazyEnablePeerAccess)); + } + } + void* p2pStatesDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&p2pStatesDev, sizeof(p2pStates))); + AT_CUDA_CHECK(cudaMemcpy( + p2pStatesDev, + p2pStates.data(), + sizeof(p2pStates), + cudaMemcpyHostToDevice)); + + void* buffersDev = nullptr; + AT_CUDA_CHECK(cudaMalloc(&buffersDev, sizeof(buffers))); + AT_CUDA_CHECK(cudaMemcpy( + buffersDev, buffers.data(), sizeof(buffers), cudaMemcpyHostToDevice)); void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); isInitialized_ = true; topology_ = topology; - p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); - buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); + std::copy(p2pStates.begin(), p2pStates.end(), p2pStates_.begin()); + std::copy(buffers.begin(), buffers.end(), buffers_.begin()); + p2pStatesDev_ = p2pStatesDev; + buffersDev_ = buffersDev; topoInfo_ = topoInfo; return true; #endif diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu 
b/torch/csrc/distributed/c10d/intra_node_comm.cu
index ac751ff7be1e..51fc6252d223 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.cu
+++ b/torch/csrc/distributed/c10d/intra_node_comm.cu
@@ -132,8 +132,6 @@ struct P2pState {
   uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices];
 };
 
-static_assert(sizeof(P2pState) <= kP2pStateSize);
-
 template
 static __global__ void oneShotAllReduceKernel(
     at::BFloat16* input,
@@ -524,7 +522,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce(
   const bool fuseInputCopy = isAligned && blocks.x < kMaxAllReduceBlocks;
   if (!fuseInputCopy) {
     AT_CUDA_CHECK(cudaMemcpyAsync(
-        symmetricMemory_->get_buffer_ptrs_dev()[rank_],
+        buffers_[rank_],
         input.data_ptr(),
         input.numel() * input.element_size(),
         cudaMemcpyDeviceToDevice,
@@ -584,7 +582,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce(
   at::cuda::OptionalCUDAGuard guard(input.get_device());
 
   AT_CUDA_CHECK(cudaMemcpyAsync(
-      symmetricMemory_->get_buffer_ptrs_dev()[rank_],
+      buffers_[rank_],
       input.data_ptr(),
       input.numel() * input.element_size(),
       cudaMemcpyDeviceToDevice,
@@ -634,7 +632,7 @@ at::Tensor IntraNodeComm::hybridCubeMeshAllReduce(
   at::cuda::OptionalCUDAGuard guard(input.get_device());
 
   AT_CUDA_CHECK(cudaMemcpyAsync(
-      symmetricMemory_->get_buffer_ptrs_dev()[rank_],
+      buffers_[rank_],
       input.data_ptr(),
       input.numel() * input.element_size(),
       cudaMemcpyDeviceToDevice,
@@ -757,7 +755,15 @@ at::Tensor IntraNodeComm::getBuffer(
     const std::vector<int64_t>& sizes,
     c10::ScalarType dtype,
     int64_t storageOffset) {
-  return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset);
+  const auto numel = std::accumulate(
+      sizes.begin(), sizes.end(), int64_t(1), std::multiplies<int64_t>());
+  const auto elementSize = c10::elementSize(dtype);
+  TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
+  auto options = at::TensorOptions().dtype(dtype).device(
+      at::kCUDA, at::cuda::current_device());
+  return at::for_blob(buffers_[rank], sizes)
+      .storage_offset(storageOffset)
+      .options(options)
+      .make_tensor();
 }
 
 } // namespace intra_node_comm
diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp
index a67df5c34586..5d7e2d426d30 100644
--- a/torch/csrc/distributed/c10d/intra_node_comm.hpp
+++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp
@@ -4,16 +4,12 @@
 #include
 #include
 #include
-#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
 #include
 
 namespace c10d::intra_node_comm {
 
-using namespace c10d::symmetric_memory;
-
 constexpr size_t kMaxDevices = 8;
 constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024;
-constexpr size_t kP2pStateSize = 2048;
 
 using NvlMesh = std::array<std::array<size_t, kMaxDevices>, kMaxDevices>;
 using HybridCubeMesh = std::array<std::array<int, 4>, kMaxDevices>;
@@ -31,7 +27,6 @@ enum class AllReduceAlgo : uint8_t {
   HCM = 3
 };
 
-// NOTE: this class will be removed soon in favor of SymmetricMemory
 class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
  public:
   IntraNodeComm(
@@ -102,8 +97,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
   */
   bool isInitialized_ = false;
   Topology topology_ = Topology::UNKNOWN;
-  void* symmetricMemoryPtr_ = nullptr;
-  c10::intrusive_ptr<SymmetricMemory> symmetricMemory_ = nullptr;
+  std::array<void*, kMaxDevices> p2pStates_{};
+  std::array<void*, kMaxDevices> buffers_{};
   void* p2pStatesDev_{};
   void* buffersDev_{};
   void* topoInfo_{};
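The rendezvous path restored above reduces to a classic CUDA IPC exchange: cudaIpcGetMemHandle() on the local allocation, an all-gather of the opaque handles through the store, and cudaIpcOpenMemHandle() on each peer handle. A condensed, self-contained sketch of that pattern — the function name and fixed peer count are illustrative, the handle exchange (storeAllGather above) is assumed to have happened already, and error checking is elided:

#include <cuda_runtime.h>

#include <array>
#include <cstddef>

constexpr size_t kPeers = 8; // mirrors kMaxDevices above

// Map every peer's exported allocation into this process's address space,
// keeping the local allocation as-is.
std::array<void*, kPeers> openPeerBuffers(
    size_t rank,
    void* localBuffer,
    const std::array<cudaIpcMemHandle_t, kPeers>& peerHandles) {
  std::array<void*, kPeers> buffers{};
  for (size_t r = 0; r < kPeers; ++r) {
    if (r == rank) {
      buffers[r] = localBuffer; // the local allocation needs no IPC mapping
    } else {
      cudaIpcOpenMemHandle(
          &buffers[r], peerHandles[r], cudaIpcMemLazyEnablePeerAccess);
    }
  }
  return buffers;
}

Unlike the reverted driver-API path (cuMemExportToShareableHandle() plus pidfd-based fd transfer), this route exchanges only opaque handle blobs and needs no file-descriptor passing between processes.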