diff --git a/BUILD.bazel b/BUILD.bazel index d3a32d0ad080..25148db02295 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -744,6 +744,7 @@ cc_library( "torch/csrc/cuda/nccl.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], diff --git a/build_variables.bzl b/build_variables.bzl index 31d8a6a95eaf..e05c94bd83f5 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -688,6 +688,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cpp", "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", + "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp index fde4da864fab..bb201b5c0397 100644 --- a/c10/cuda/driver_api.cpp +++ b/c10/cuda/driver_api.cpp @@ -20,6 +20,12 @@ DriverAPI create_driver_api() { C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY) #undef LOOKUP_LIBCUDA_ENTRY +#define LOOKUP_LIBCUDA_ENTRY(name) \ + r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \ + dlerror(); + C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY) +#undef LOOKUP_LIBCUDA_ENTRY + if (handle_1) { #define LOOKUP_NVML_ENTRY(name) \ r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \ diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index cbbdf16823ec..0a511f7f849a 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -31,6 +31,15 @@ _(cuMemImportFromShareableHandle) \ _(cuGetErrorString) +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) +#define C10_LIBCUDA_DRIVER_API_12030(_) \ + _(cuMulticastAddDevice) \ + _(cuMulticastBindMem) \ + _(cuMulticastCreate) +#else +#define C10_LIBCUDA_DRIVER_API_12030(_) +#endif + #define C10_NVML_DRIVER_API(_) \ _(nvmlInit_v2) \ _(nvmlDeviceGetHandleByPciBusId_v2) \ @@ -43,6 +52,7 @@ namespace c10::cuda { struct DriverAPI { #define CREATE_MEMBER(name) decltype(&name) name##_; C10_LIBCUDA_DRIVER_API(CREATE_MEMBER) + C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER) C10_NVML_DRIVER_API(CREATE_MEMBER) #undef CREATE_MEMBER static DriverAPI* get(); diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index 3410586120d5..c27ad2f10f9e 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist +from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory from torch.distributed._symmetric_memory import ( _fused_all_gather_matmul_fallback, @@ -44,6 +45,17 @@ def requires_cuda_p2p_access(): ) +def requires_multicast_support(): + has_multicast_support = ( + torch.cuda.is_available() + and _SymmetricMemory.has_multicast_support(DeviceType.CUDA) + ) + return skip_but_pass_in_sandcastle_if( + not has_multicast_support, + "multicast support is not available", + ) + + @instantiate_parametrized_tests @requires_cuda_p2p_access() class SymmetricMemoryTest(MultiProcessTestCase): @@ -95,7 +107,6 @@ class SymmetricMemoryTest(MultiProcessTestCase): @skipIfRocm @skip_if_lt_x_gpu(2) def test_cuda_nvlink_connectivity_detection(self) -> None: - from torch._C._autograd import 
DeviceType from torch._C._distributed_c10d import _detect_dma_connectivity connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink") @@ -422,6 +433,73 @@ class SymmetricMemoryTest(MultiProcessTestCase): dist.destroy_process_group() + @skip_if_lt_x_gpu(2) + @requires_multicast_support() + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_multimem_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + + t = _SymmetricMemory.empty_strided_p2p( + size=(16384,), + stride=(1,), + dtype=dtype, + device=self.device, + group_name=group_name, + ).fill_(1) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) + + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + x = t[shift : shift + numel] + + torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name) + self.assertTrue(x.eq(self.world_size).all().item()) + + # Head and tail should not be written + self.assertTrue(t[:shift].eq(1).all().item()) + self.assertTrue(t[shift + numel :].eq(1).all().item()) + dist.destroy_process_group() + + @skip_if_lt_x_gpu(2) + @requires_multicast_support() + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_multimem_one_shot_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + + t = _SymmetricMemory.empty_strided_p2p( + size=(16384,), + stride=(1,), + dtype=dtype, + device=self.device, + group_name=group_name, + ).fill_(0) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) + + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + x = t[shift : shift + numel] + x.fill_(1) + + res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name) + self.assertTrue(res.eq(self.world_size).all().item()) + dist.destroy_process_group() + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h new file mode 100644 index 000000000000..78d474dc5c7f --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h @@ -0,0 +1,256 @@ +#pragma once + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010 +#define NVCC_SUPPORTS_MULTICAST 1 +#endif + +#include + +namespace c10d::symmetric_memory { + +constexpr size_t max_num_threads_per_block = 1024; +constexpr size_t max_num_blocks = 8; + +template +size_t get_alignment(T ptr_or_size) { + auto val = reinterpret_cast(ptr_or_size); + if (val % 16 == 0) { + return 16; + } else if (val % 8 == 0) { + return 8; + } else if (val % 4 == 0) { + return 4; + } else if (val % 2 == 0) { + return 2; + } else { + return 1; + } +} + +template <> +size_t get_alignment(size_t size) { + return get_alignment(reinterpret_cast(size)); +} + +__device__ __forceinline__ uint32_t +cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + uint32_t old_val; + asm 
volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;" + : "=r"(old_val) + : "l"(addr), "r"(compare), "r"(val) + : "memory"); + return old_val; +#endif +} + +__device__ __forceinline__ uint32_t +cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + uint32_t old_val; + asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;" + : "=r"(old_val) + : "l"(addr), "r"(compare), "r"(val) + : "memory"); + return old_val; +#endif +} + +__device__ __forceinline__ void release_signal(uint32_t* addr) { + while (cas_release_sys(addr, 0, 1) != 0) + ; +} + +__device__ __forceinline__ void wait_signal(uint32_t* addr) { + while (cas_sys(addr, 1, 0) != 1) + ; +} + +__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + uint32_t val; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(val) + : "l"(addr) + : "memory"); + return val; +#endif +} + +// Perform a barrier to establish observation order between memory operations +// issued before and after the barrier. +__device__ __forceinline__ void barrier( + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); + } + __syncthreads(); +} + +// Perform a barrier and establish causality order between memory operations +// issued before the calling kernel on all devices and memory operations +// issued after this function by all thread in the calling kernel. +// +// NOTE: this function does NOT ensure that memory operations issues in the +// current kernel are visible to all threads in the current kernel. +// +// | mem ops (guaranteed to be visible by all threads at point T) +// | kernel K +// | +- mem ops (not guaranteed to be visible all threads at point T) +// | +- barrier_and_acquire_previous_kernel_writes() +// | +- point T +// v +__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes( + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + if (threadIdx.x < world_size) { + auto target_rank = threadIdx.x; + release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); + } + __syncthreads(); + // At this point, we established observation order between memory operations + // issued before and after the barrier. Now we convert the observation order + // into causality order by having every thread acquire the signals released + // by threads on peer devices. Due to the implicit synchronizes-with + // relationships at task/kernel boundaries, acquiring the signal released by + // thread T in kernel K transitively acquires memory operations issued prior + // to kernel K. 
+ // + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference + for (size_t target_rank = 0; target_rank < world_size; ++target_rank) { + acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); + } +} + +template +inline constexpr bool dependent_bool_value = Value; + +template +inline constexpr bool dependent_false = dependent_bool_value; + +template +union Vec; + +template <> +union Vec<4> { + uint16_t u16[2]; + uint32_t u32, as_scalar; +}; + +template <> +union Vec<8> { + uint16_t u16[4]; + uint32_t u32[2]; + uint64_t u64, as_scalar; +}; + +template <> +union alignas(16) Vec<16> { + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + uint4 u128, as_scalar; +}; + +template +struct MultimemLdReduce { + template + __device__ __inline__ Vec operator()(T* mc_ptr) { + static_assert(dependent_false); + } +}; + +template +__device__ __inline__ Vec multimem_ld_reduce_add(T* mc_ptr) { + MultimemLdReduce functor; + return functor.template operator()(mc_ptr); +} + +#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) +#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \ + template <> \ + struct MultimemLdReduce { \ + template \ + __device__ __inline__ Vec operator()(type* mc_ptr) { \ + CUDA_KERNEL_ASSERT(false); \ + } \ + }; +#else +#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \ + template <> \ + struct MultimemLdReduce { \ + template \ + __device__ __inline__ Vec operator()(type* mc_ptr) { \ + Vec vec; \ + if constexpr (Alignment == 16) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \ + " {%0,%1,%2,%3}, [%4];" \ + : "=r"(vec.u32[0]), \ + "=r"(vec.u32[1]), \ + "=r"(vec.u32[2]), \ + "=r"(vec.u32[3]) \ + : "l"(mc_ptr) \ + : "memory"); \ + } else if constexpr (Alignment == 8) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \ + " {%0,%1}, [%2];" \ + : "=r"(vec.u32[0]), "=r"(vec.u32[1]) \ + : "l"(mc_ptr) \ + : "memory"); \ + } else if constexpr (Alignment == 4) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add." 
asm_type " %0, [%1];" \ + : "=r"(vec.u32) \ + : "l"(mc_ptr) \ + : "memory"); \ + } \ + return vec; \ + } \ + }; +#endif + +SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2"); +SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32"); + +template +__device__ __inline__ void multimem_st(T* mc_ptr, Vec& vec) { +#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) + CUDA_KERNEL_ASSERT(false); +#else + if constexpr (Alignment == 16) { + asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" + : + : "l"(mc_ptr), + "r"(vec.u32[0]), + "r"(vec.u32[1]), + "r"(vec.u32[2]), + "r"(vec.u32[3]) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};" + : + : "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1]) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("multimem.st.relaxed.sys.global.f32 [%0], %1;" + : + : "l"(mc_ptr), "r"(vec.u32) + : "memory"); + } else { + static_assert(dependent_false); + } +#endif +} + +} // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu index 13fef18c67c9..17f13c1fcb94 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -14,8 +14,20 @@ #include #include +#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 +#define CUDART_SUPPORTS_MULTICAST +#endif + namespace { +bool has_multicast_support() { +#if defined(CUDART_SUPPORTS_MULTICAST) + return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr; +#else + return false; +#endif +} + class IpcChannel { public: IpcChannel() : socket_name_(get_socket_name(getpid())) { @@ -61,9 +73,7 @@ class IpcChannel { memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd)); TORCH_CHECK( - sendmsg(socket_, &msg, 0) > 0, - "Failed to send fd: ", - strerror(errno)); + sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno)); } int recv_fd() { @@ -110,6 +120,25 @@ class IpcChannel { return fds; } + int broadcast_fds( + int rank, + int src_rank, + const std::vector& pids, + int fd) { + size_t world_size = pids.size(); + + if (rank == src_rank) { + for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) { + if (dst_rank == rank) { + continue; + } + send_fd(pids[dst_rank], fd); + } + return fd; + } + return recv_fd(); + } + private: static std::string get_socket_name(int pid) { const char* tmp_dir = "/tmp"; @@ -213,6 +242,8 @@ CUDASymmetricMemory::CUDASymmetricMemory( size_t block_size, std::vector buffers, std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, size_t buffer_size, int local_device_idx, int rank, @@ -221,6 +252,8 @@ CUDASymmetricMemory::CUDASymmetricMemory( block_size_(block_size), buffers_(std::move(buffers)), signal_pads_(std::move(signal_pads)), + mc_handle_(mc_handle), + mc_addr_(mc_addr), buffer_size_(buffer_size), local_device_idx_(local_device_idx), rank_(rank), @@ -285,6 +318,14 @@ size_t CUDASymmetricMemory::get_signal_pad_size() { return signal_pad_size; } +bool CUDASymmetricMemory::has_multicast_support() { + return ::has_multicast_support(); +} + +void* CUDASymmetricMemory::get_multicast_ptr() { + return mc_addr_; +} + at::Tensor CUDASymmetricMemory::get_buffer( int rank, c10::IntArrayRef sizes, @@ -601,6 +642,46 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( store_barrier(store, rank, world_size); close(block_fd); + CUmemGenericAllocationHandle mc_handle{}; + void* mc_addr = nullptr; +#if defined(CUDART_SUPPORTS_MULTICAST) + 
// We have to further check if the driver supports multicast + if (has_multicast_support()) { + // Rank 0 creates a multicast object and share it with peers + if (rank == 0) { + CUmulticastObjectProp mc_prop{}; + mc_prop.numDevices = world_size; + mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + mc_prop.size = block->block_size; + + CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop); + TORCH_CHECK(res == CUDA_SUCCESS); + + int mc_fd; + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0)); + ipc_channel.broadcast_fds(rank, 0, pids, mc_fd); + // Ref count is incremented as soon as SCM_RIGHTS send happens + close(mc_fd); + } else { + int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &mc_handle, + (void*)(uintptr_t)mc_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(mc_fd); + } + // All rank adds their physical allocation to the multicast object + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_( + mc_handle, 0, block->handle, 0, block->block_size, 0)); + + map_block(&mc_addr, mc_handle, block->block_size, block->device_idx); + store_barrier(store, rank, world_size); + } +#endif + // Initializing CUDASymmetricMemory with an allocation transfers its // ownership to the CUDASymmetricMemory object. So that outstanding // references to the CUDASymmetricMemory object can keep the allocation @@ -610,6 +691,8 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( block->block_size, std::move(buffers), std::move(signal_pads), + mc_handle, + mc_addr, block->buffer_size, block->device_idx, group_info.rank, @@ -630,6 +713,10 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) { return block->symm_mem != nullptr; } +bool CUDASymmetricMemoryAllocator::has_multicast_support() { + return ::has_multicast_support(); +} + c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block(void* ptr) { std::shared_lock lock(mutex_); auto it = ptr_to_block_.find(ptr); diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp index 82e75d22c84f..caede2a0a491 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp @@ -20,6 +20,8 @@ class CUDASymmetricMemory : public SymmetricMemory { size_t block_size, std::vector buffers, std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, size_t buffer_size, int local_device_idx, int rank, @@ -34,6 +36,9 @@ class CUDASymmetricMemory : public SymmetricMemory { size_t get_buffer_size() override; size_t get_signal_pad_size() override; + bool has_multicast_support() override; + void* get_multicast_ptr() override; + at::Tensor get_buffer( int rank, c10::IntArrayRef sizes, @@ -52,6 +57,8 @@ class CUDASymmetricMemory : public SymmetricMemory { size_t block_size_; std::vector buffers_; std::vector signal_pads_; + HandleType mc_handle_; + void* mc_addr_; size_t buffer_size_; int local_device_idx_; int rank_; @@ -95,6 +102,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { size_t get_alloc_size(void* ptr) override; c10::intrusive_ptr rendezvous(void* ptr) override; bool is_rendezvous_completed(void* ptr) override; + bool has_multicast_support() override; private: c10::intrusive_ptr find_block(void* ptr); diff 
--git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu new file mode 100644 index 000000000000..cedcca2c9761 --- /dev/null +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -0,0 +1,267 @@ +#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#endif + +#include + +#include +#include + +namespace { + +using namespace c10d::symmetric_memory; + +size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) { + const size_t min_alignment = std::max(4l, input.element_size()); + // Only check the offset since the multicast address is always at least + // 128-bit aligned + const size_t ptr_alignment = get_alignment( + static_cast(input.storage_offset() * input.element_size())); + TORCH_CHECK( + ptr_alignment >= min_alignment, + op_name, + "<", + input.scalar_type(), + ">: input ptr + offset must be at least ", + min_alignment, + "-byte aligned."); + + const size_t size_alignment = + get_alignment(static_cast(input.numel() * input.element_size())); + TORCH_CHECK( + size_alignment >= min_alignment, + op_name, + "<", + input.scalar_type(), + ">: input size must be at least ", + min_alignment, + "-byte aligned."); + return std::min(ptr_alignment, size_alignment); +} + +void init_elementwise_launch_config( + size_t numel, + size_t element_size, + size_t alignment, + size_t splits, + int& num_blocks, + int& num_threads) { + // Align to preserve alignment in each split + const size_t aligned_numel = at::round_up(numel, alignment * splits); + const size_t numel_per_split = aligned_numel / splits; + const size_t numel_per_thread = alignment / element_size; + + if (numel_per_split <= max_num_threads_per_block * numel_per_thread) { + num_blocks = 1; + num_threads = at::round_up( + at::ceil_div(numel_per_split, numel_per_thread), + static_cast(C10_WARP_SIZE)); + } else { + num_blocks = std::min( + at::ceil_div( + numel_per_split, max_num_threads_per_block * numel_per_thread), + max_num_blocks); + num_threads = max_num_threads_per_block; + } +} + +template +static __global__ void multimem_all_reduce_kernel( + T* input_mc_ptr, + size_t numel, + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + static_assert(alignment % sizeof(T) == 0); + constexpr size_t numel_per_thread = alignment / sizeof(T); + + barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size); + + const size_t numel_per_rank = + at::round_up(numel, alignment * world_size) / world_size; + const size_t start = numel_per_rank * rank; + + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; + auto stride = blockDim.x * gridDim.x * numel_per_thread; + for (size_t i = offset; i < numel_per_rank; i += stride) { + if (start + i >= numel) { + continue; + } + auto vec = multimem_ld_reduce_add(input_mc_ptr + start + i); + multimem_st(input_mc_ptr + start + i, vec); + } + // Establish observation order - all writes are in-flight beyond this point. + barrier(signal_pads, rank, world_size); + // Establish causality order - all writes are visible to all devices beyond + // this point. 
+ __threadfence_system(); +} + +at::Tensor multimem_all_reduce_( + const at::Tensor& input, + std::string reduce_op, + std::string group_name) { + TORCH_CHECK( + input.is_contiguous(), "multimem_all_reduce_: input must be contiguous."); + TORCH_CHECK( + reduce_op == "sum", + "multimem_all_reduce_: only sum is supported for now."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input); + TORCH_CHECK( + symm_mem != nullptr, + "multimem_all_reduce_: input must be allocated with empty_strided_p2p()."); + TORCH_CHECK( + symm_mem->has_multicast_support(), + "multimem_all_reduce_: multicast support is required."); + + const size_t alignment = + get_and_verify_alignment(input, "multimem_all_reduce_"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + symm_mem->get_world_size(), + num_blocks, + num_threads); + +#define DISPATCH(scalar_t, kernel_alignment) \ + if (alignment == kernel_alignment) { \ + multimem_all_reduce_kernel \ + <<>>( \ + reinterpret_cast(symm_mem->get_multicast_ptr()) + \ + input.storage_offset(), \ + input.numel(), \ + reinterpret_cast(symm_mem->get_signal_pad_ptrs_dev()), \ + symm_mem->get_rank(), \ + symm_mem->get_world_size()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } + + AT_DISPATCH_SWITCH( + input.scalar_type(), + "multimem_all_reduce", + AT_DISPATCH_CASE(at::kBFloat16, [&] { + DISPATCH(scalar_t, 16); + DISPATCH(scalar_t, 8); + DISPATCH(scalar_t, 4); + }) AT_DISPATCH_CASE(at::kFloat, [&] { + DISPATCH(scalar_t, 16); + DISPATCH(scalar_t, 8); + DISPATCH(scalar_t, 4); + })); + +#undef DISPATCH + return input; +} + +template +static __global__ void multimem_one_shot_all_reduce_kernel( + T* input_mc_ptr, + T* output_ptr, + size_t numel, + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + static_assert(alignment % sizeof(T) == 0); + constexpr size_t numel_per_thread = alignment / sizeof(T); + + barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size); + + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; + auto stride = blockDim.x * gridDim.x * numel_per_thread; + for (size_t i = offset; i < numel; i += stride) { + auto vec = multimem_ld_reduce_add(input_mc_ptr + i); + *reinterpret_cast(output_ptr + i) = vec.as_scalar; + } +} + +at::Tensor multimem_one_shot_all_reduce( + const at::Tensor& input, + std::string reduce_op, + std::string group_name) { + TORCH_CHECK( + input.is_contiguous(), + "multimem_one_shot_all_reduce: input must be contiguous."); + TORCH_CHECK( + reduce_op == "sum", + "multimem_one_shot_all_reduce: only sum is supported for now."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input); + TORCH_CHECK( + symm_mem != nullptr, + "multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p()."); + TORCH_CHECK( + symm_mem->has_multicast_support(), + "multimem_one_shot_all_reduce: requires multicast support."); + + auto output = at::empty_like(input); + + const size_t alignment = + get_and_verify_alignment(input, "multimem_one_shot_all_reduce"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + 1, + num_blocks, + num_threads); + +#define DISPATCH(scalar_t, kernel_alignment) \ + if (alignment == kernel_alignment) { \ + multimem_one_shot_all_reduce_kernel \ + <<>>( \ + reinterpret_cast(symm_mem->get_multicast_ptr()) + \ + input.storage_offset(), \ + output.data_ptr(), \ + input.numel(), \ + 
reinterpret_cast(symm_mem->get_signal_pad_ptrs_dev()), \ + symm_mem->get_rank(), \ + symm_mem->get_world_size()); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + } + + AT_DISPATCH_SWITCH( + input.scalar_type(), + "multimem_all_reduce", + AT_DISPATCH_CASE(at::kBFloat16, [&] { + DISPATCH(scalar_t, 16); + DISPATCH(scalar_t, 8); + DISPATCH(scalar_t, 4); + }) AT_DISPATCH_CASE(at::kFloat, [&] { + DISPATCH(scalar_t, 16); + DISPATCH(scalar_t, 8); + DISPATCH(scalar_t, 4); + })); + + return output; +} + +TORCH_LIBRARY_FRAGMENT(symm_mem, m) { + m.def( + "multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor", + torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_), + {at::Tag::pt2_compliant_tag}); + + m.def( + "multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", + torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce), + {at::Tag::pt2_compliant_tag}); +} + +} // namespace + +#endif diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 34f794f5a58a..ddfbe3d594f0 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p( TORCH_API c10::intrusive_ptr rendezvous( const at::Tensor& tensor) { auto allocator = get_allocator(tensor.device().type()); - return allocator->rendezvous(tensor.data_ptr()); + return allocator->rendezvous(tensor.storage().data_ptr().get()); } c10::intrusive_ptr get_symmetric_memory( @@ -189,5 +189,9 @@ c10::intrusive_ptr get_symmetric_memory( return allocator->rendezvous(tensor.data_ptr()); } +TORCH_API bool has_multicast_support(c10::DeviceType device_type) { + auto allocator = get_allocator(device_type); + return allocator->has_multicast_support(); +} } // namespace symmetric_memory } // namespace c10d diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp index a9672874f608..babdc6345aae 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp @@ -51,6 +51,9 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { virtual size_t get_buffer_size() = 0; virtual size_t get_signal_pad_size() = 0; + virtual bool has_multicast_support() = 0; + virtual void* get_multicast_ptr() = 0; + virtual at::Tensor get_buffer( int rank, c10::IntArrayRef sizes, @@ -78,6 +81,7 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { virtual size_t get_alloc_size(void* ptr) = 0; virtual c10::intrusive_ptr rendezvous(void* ptr) = 0; virtual bool is_rendezvous_completed(void* ptr) = 0; + virtual bool has_multicast_support() = 0; }; C10_EXPORT bool is_finalizing(); @@ -150,5 +154,6 @@ TORCH_API c10::intrusive_ptr rendezvous( TORCH_API c10::intrusive_ptr get_symmetric_memory( const at::Tensor& tensor); +TORCH_API bool has_multicast_support(c10::DeviceType device_type); } // namespace symmetric_memory } // namespace c10d diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 795113d8f4e8..c8f9dff37f06 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1044,6 +1044,9 @@ This class does not support ``__members__`` property.)"); .def_static( "get_symmetric_memory", &::c10d::symmetric_memory::get_symmetric_memory) + .def_static( + "has_multicast_support", + &::c10d::symmetric_memory::has_multicast_support) .def_property_readonly("rank", 
&SymmetricMemory::get_rank) .def_property_readonly("world_size", &SymmetricMemory::get_world_size) .def_property_readonly(