[SymmetricMemory] introduce multicast support, multimem_all_reduce_ and multimem_one_shot_all_reduce (#133424)

### Summary
- Added multicast support to SymmetricMemory. If the CUDA runtime and CUDA driver support multicast, SymmetricMemory associates all peer buffers with a multicast object and exposes the multicast virtual address.
- Implemented `multimem_all_reduce_` and `multimem_one_shot_all_reduce` on top of the multicast support. The two variants show different performance characteristics at different message sizes. We plan to use Inductor for collective algorithm selection (and for the required symmetric memory buffer allocation). A minimal usage sketch is included below.
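
For context, the sketch below shows how the new ops are exercised, mirroring the test added in this PR. It assumes one process per GPU (e.g. launched with `torchrun`); the `enable_symm_mem_for_group` call and the `bfloat16`/buffer-size choices come from the existing test harness rather than from this change, so treat the setup as illustrative.

```python
import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import enable_symm_mem_for_group  # existing setup helper (assumed)

dist.init_process_group("nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

group_name = dist.group.WORLD.group_name
enable_symm_mem_for_group(group_name)  # register the group with the symmetric memory allocator

# The ops require driver-level multicast support (CUDA >= 12.3, NVLS-capable NVSwitch).
assert _SymmetricMemory.has_multicast_support(DeviceType.CUDA)

# Buffers must come from the symmetric memory allocator so that rendezvous can
# associate all peer buffers with a multicast object.
t = _SymmetricMemory.empty_strided_p2p(
    size=(16384,),
    stride=(1,),
    dtype=torch.bfloat16,
    device=device,
    group_name=group_name,
).fill_(1)

# In-place variant: after the call, every rank sees t == world_size.
torch.ops.symm_mem.multimem_all_reduce_(t, "sum", group_name)

# One-shot variant: reduces via the multicast address into a regular output tensor.
out = torch.ops.symm_mem.multimem_one_shot_all_reduce(t, "sum", group_name)

dist.destroy_process_group()
```

As a rough intuition for the two variants: `multimem_all_reduce_` partitions the buffer across ranks and both loads and stores through the multicast address (with barriers on both sides), while `multimem_one_shot_all_reduce` has each rank reduce the full buffer through the multicast address into a local output after a single barrier. This is the source of the different performance characteristics across message sizes mentioned above.
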

### Benchmark

8x H100 (a non-standard variant with HBM2e at 650 W), connected by NVSwitch V3 with NVLS support.

![image](https://github.com/user-attachments/assets/4998a16b-c2c0-4797-9dd0-1da2303df947)

![image](https://github.com/user-attachments/assets/278ad361-52cb-4864-82c6-bb67e8d0a3fe)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133424
Approved by: https://github.com/yf225, https://github.com/weifengpy
Authored by Yifu Wang on 2024-08-20 15:09:12 -07:00, committed by PyTorch MergeBot
parent 8337b4d96e
commit 66d3eb783c
12 changed files with 731 additions and 5 deletions


@ -744,6 +744,7 @@ cc_library(
"torch/csrc/cuda/nccl.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
],


@ -688,6 +688,7 @@ libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
"torch/csrc/distributed/c10d/intra_node_comm.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
"torch/csrc/distributed/c10d/Utils.cu",
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",


@ -20,6 +20,12 @@ DriverAPI create_driver_api() {
C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
#undef LOOKUP_LIBCUDA_ENTRY
#define LOOKUP_LIBCUDA_ENTRY(name) \
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
dlerror();
C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
#undef LOOKUP_LIBCUDA_ENTRY
if (handle_1) {
#define LOOKUP_NVML_ENTRY(name) \
r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \


@ -31,6 +31,15 @@
_(cuMemImportFromShareableHandle) \
_(cuGetErrorString)
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_12030(_) \
_(cuMulticastAddDevice) \
_(cuMulticastBindMem) \
_(cuMulticastCreate)
#else
#define C10_LIBCUDA_DRIVER_API_12030(_)
#endif
#define C10_NVML_DRIVER_API(_) \
_(nvmlInit_v2) \
_(nvmlDeviceGetHandleByPciBusId_v2) \
@ -43,6 +52,7 @@ namespace c10::cuda {
struct DriverAPI {
#define CREATE_MEMBER(name) decltype(&name) name##_;
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
static DriverAPI* get();


@ -2,6 +2,7 @@
import torch
import torch.distributed as dist
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
@ -44,6 +45,17 @@ def requires_cuda_p2p_access():
)
def requires_multicast_support():
has_multicast_support = (
torch.cuda.is_available()
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
)
return skip_but_pass_in_sandcastle_if(
not has_multicast_support,
"multicast support is not available",
)
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmetricMemoryTest(MultiProcessTestCase):
@ -95,7 +107,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_cuda_nvlink_connectivity_detection(self) -> None:
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _detect_dma_connectivity
connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
@ -422,6 +433,73 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
t = _SymmetricMemory.empty_strided_p2p(
size=(16384,),
stride=(1,),
dtype=dtype,
device=self.device,
group_name=group_name,
).fill_(1)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
x = t[shift : shift + numel]
torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
self.assertTrue(x.eq(self.world_size).all().item())
# Head and tail should not be written
self.assertTrue(t[:shift].eq(1).all().item())
self.assertTrue(t[shift + numel :].eq(1).all().item())
dist.destroy_process_group()
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
@parametrize("dtype", [torch.float, torch.bfloat16])
@parametrize("align_bytes", [4, 8, 16])
@parametrize("size_bytes", [4, 8192, 8196])
def test_multimem_one_shot_all_reduce(
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
) -> None:
self._init_process()
group_name = dist.group.WORLD.group_name
t = _SymmetricMemory.empty_strided_p2p(
size=(16384,),
stride=(1,),
dtype=dtype,
device=self.device,
group_name=group_name,
).fill_(0)
self.assertTrue(t.data_ptr() % 16 == 0)
self.assertTrue(align_bytes % t.element_size() == 0)
self.assertTrue(size_bytes % t.element_size() == 0)
shift = align_bytes // t.element_size()
numel = size_bytes // t.element_size()
x = t[shift : shift + numel]
x.fill_(1)
res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
self.assertTrue(res.eq(self.world_size).all().item())
dist.destroy_process_group()
if __name__ == "__main__":
run_tests()


@ -0,0 +1,256 @@
#pragma once
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
#define NVCC_SUPPORTS_MULTICAST 1
#endif
#include <ATen/ATen.h>
namespace c10d::symmetric_memory {
constexpr size_t max_num_threads_per_block = 1024;
constexpr size_t max_num_blocks = 8;
template <typename T>
size_t get_alignment(T ptr_or_size) {
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
if (val % 16 == 0) {
return 16;
} else if (val % 8 == 0) {
return 8;
} else if (val % 4 == 0) {
return 4;
} else if (val % 2 == 0) {
return 2;
} else {
return 1;
}
}
template <>
size_t get_alignment<size_t>(size_t size) {
return get_alignment(reinterpret_cast<void*>(size));
}
__device__ __forceinline__ uint32_t
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t old_val;
asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
: "=r"(old_val)
: "l"(addr), "r"(compare), "r"(val)
: "memory");
return old_val;
#endif
}
__device__ __forceinline__ uint32_t
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t old_val;
asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
: "=r"(old_val)
: "l"(addr), "r"(compare), "r"(val)
: "memory");
return old_val;
#endif
}
__device__ __forceinline__ void release_signal(uint32_t* addr) {
while (cas_release_sys(addr, 0, 1) != 0)
;
}
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
while (cas_sys(addr, 1, 0) != 1)
;
}
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
uint32_t val;
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(val)
: "l"(addr)
: "memory");
return val;
#endif
}
// Perform a barrier to establish observation order between memory operations
// issued before and after the barrier.
__device__ __forceinline__ void barrier(
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
if (threadIdx.x < world_size) {
auto target_rank = threadIdx.x;
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
__syncthreads();
}
// Perform a barrier and establish causality order between memory operations
// issued before the calling kernel on all devices and memory operations
// issued after this function by all threads in the calling kernel.
//
// NOTE: this function does NOT ensure that memory operations issued in the
// current kernel are visible to all threads in the current kernel.
//
// | mem ops (guaranteed to be visible to all threads at point T)
// | kernel K
// | +- mem ops (not guaranteed to be visible to all threads at point T)
// | +- barrier_and_acquire_previous_kernel_writes()
// | +- point T
// v
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
if (threadIdx.x < world_size) {
auto target_rank = threadIdx.x;
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
__syncthreads();
// At this point, we established observation order between memory operations
// issued before and after the barrier. Now we convert the observation order
// into causality order by having every thread acquire the signals released
// by threads on peer devices. Due to the implicit synchronizes-with
// relationships at task/kernel boundaries, acquiring the signal released by
// thread T in kernel K transitively acquires memory operations issued prior
// to kernel K.
//
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
}
}
template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;
template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
template <int Size>
union Vec;
template <>
union Vec<4> {
uint16_t u16[2];
uint32_t u32, as_scalar;
};
template <>
union Vec<8> {
uint16_t u16[4];
uint32_t u32[2];
uint64_t u64, as_scalar;
};
template <>
union alignas(16) Vec<16> {
uint16_t u16[8];
uint32_t u32[4];
uint64_t u64[2];
uint4 u128, as_scalar;
};
template <typename T>
struct MultimemLdReduce {
template <int Alignment>
__device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
static_assert(dependent_false<T>);
}
};
template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
MultimemLdReduce<T> functor;
return functor.template operator()<Alignment>(mc_ptr);
}
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
template <> \
struct MultimemLdReduce<type> { \
template <int Alignment> \
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
CUDA_KERNEL_ASSERT(false); \
} \
};
#else
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
template <> \
struct MultimemLdReduce<type> { \
template <int Alignment> \
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
Vec<Alignment> vec; \
if constexpr (Alignment == 16) { \
asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \
" {%0,%1,%2,%3}, [%4];" \
: "=r"(vec.u32[0]), \
"=r"(vec.u32[1]), \
"=r"(vec.u32[2]), \
"=r"(vec.u32[3]) \
: "l"(mc_ptr) \
: "memory"); \
} else if constexpr (Alignment == 8) { \
asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \
" {%0,%1}, [%2];" \
: "=r"(vec.u32[0]), "=r"(vec.u32[1]) \
: "l"(mc_ptr) \
: "memory"); \
} else if constexpr (Alignment == 4) { \
asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
: "=r"(vec.u32) \
: "l"(mc_ptr) \
: "memory"); \
} \
return vec; \
} \
};
#endif
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");
template <int Alignment, typename T>
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
CUDA_KERNEL_ASSERT(false);
#else
if constexpr (Alignment == 16) {
asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
:
: "l"(mc_ptr),
"r"(vec.u32[0]),
"r"(vec.u32[1]),
"r"(vec.u32[2]),
"r"(vec.u32[3])
: "memory");
} else if constexpr (Alignment == 8) {
asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
:
: "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
: "memory");
} else if constexpr (Alignment == 4) {
asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
:
: "l"(mc_ptr), "r"(vec.u32)
: "memory");
} else {
static_assert(dependent_false<T>);
}
#endif
}
} // namespace c10d::symmetric_memory


@ -14,8 +14,20 @@
#include <sys/un.h>
#include <unistd.h>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#define CUDART_SUPPORTS_MULTICAST
#endif
namespace {
bool has_multicast_support() {
#if defined(CUDART_SUPPORTS_MULTICAST)
return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
#else
return false;
#endif
}
class IpcChannel {
public:
IpcChannel() : socket_name_(get_socket_name(getpid())) {
@ -61,9 +73,7 @@ class IpcChannel {
memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0,
"Failed to send fd: ",
strerror(errno));
sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
}
int recv_fd() {
@ -110,6 +120,25 @@ class IpcChannel {
return fds;
}
int broadcast_fds(
int rank,
int src_rank,
const std::vector<int>& pids,
int fd) {
size_t world_size = pids.size();
if (rank == src_rank) {
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
if (dst_rank == rank) {
continue;
}
send_fd(pids[dst_rank], fd);
}
return fd;
}
return recv_fd();
}
private:
static std::string get_socket_name(int pid) {
const char* tmp_dir = "/tmp";
@ -213,6 +242,8 @@ CUDASymmetricMemory::CUDASymmetricMemory(
size_t block_size,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
void* mc_addr,
size_t buffer_size,
int local_device_idx,
int rank,
@ -221,6 +252,8 @@ CUDASymmetricMemory::CUDASymmetricMemory(
block_size_(block_size),
buffers_(std::move(buffers)),
signal_pads_(std::move(signal_pads)),
mc_handle_(mc_handle),
mc_addr_(mc_addr),
buffer_size_(buffer_size),
local_device_idx_(local_device_idx),
rank_(rank),
@ -285,6 +318,14 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
return signal_pad_size;
}
bool CUDASymmetricMemory::has_multicast_support() {
return ::has_multicast_support();
}
void* CUDASymmetricMemory::get_multicast_ptr() {
return mc_addr_;
}
at::Tensor CUDASymmetricMemory::get_buffer(
int rank,
c10::IntArrayRef sizes,
@ -601,6 +642,46 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
store_barrier(store, rank, world_size);
close(block_fd);
CUmemGenericAllocationHandle mc_handle{};
void* mc_addr = nullptr;
#if defined(CUDART_SUPPORTS_MULTICAST)
// We have to further check if the driver supports multicast
if (has_multicast_support()) {
// Rank 0 creates a multicast object and shares it with peers
if (rank == 0) {
CUmulticastObjectProp mc_prop{};
mc_prop.numDevices = world_size;
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
mc_prop.size = block->block_size;
CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
TORCH_CHECK(res == CUDA_SUCCESS);
int mc_fd;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
// Ref count is incremented as soon as SCM_RIGHTS send happens
close(mc_fd);
} else {
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
&mc_handle,
(void*)(uintptr_t)mc_fd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(mc_fd);
}
// All ranks add their physical allocation to the multicast object
C10_CUDA_DRIVER_CHECK(
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
mc_handle, 0, block->handle, 0, block->block_size, 0));
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
store_barrier(store, rank, world_size);
}
#endif
// Initializing CUDASymmetricMemory with an allocation transfers its
// ownership to the CUDASymmetricMemory object. So that outstanding
// references to the CUDASymmetricMemory object can keep the allocation
@ -610,6 +691,8 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
block->block_size,
std::move(buffers),
std::move(signal_pads),
mc_handle,
mc_addr,
block->buffer_size,
block->device_idx,
group_info.rank,
@ -630,6 +713,10 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
return block->symm_mem != nullptr;
}
bool CUDASymmetricMemoryAllocator::has_multicast_support() {
return ::has_multicast_support();
}
c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
std::shared_lock lock(mutex_);
auto it = ptr_to_block_.find(ptr);


@ -20,6 +20,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t block_size,
std::vector<void*> buffers,
std::vector<void*> signal_pads,
HandleType mc_handle,
void* mc_addr,
size_t buffer_size,
int local_device_idx,
int rank,
@ -34,6 +36,9 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t get_buffer_size() override;
size_t get_signal_pad_size() override;
bool has_multicast_support() override;
void* get_multicast_ptr() override;
at::Tensor get_buffer(
int rank,
c10::IntArrayRef sizes,
@ -52,6 +57,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
size_t block_size_;
std::vector<void*> buffers_;
std::vector<void*> signal_pads_;
HandleType mc_handle_;
void* mc_addr_;
size_t buffer_size_;
int local_device_idx_;
int rank_;
@ -95,6 +102,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
size_t get_alloc_size(void* ptr) override;
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
bool is_rendezvous_completed(void* ptr) override;
bool has_multicast_support() override;
private:
c10::intrusive_ptr<Block> find_block(void* ptr);


@ -0,0 +1,267 @@
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <torch/library.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
namespace {
using namespace c10d::symmetric_memory;
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
const size_t min_alignment = std::max(4l, input.element_size());
// Only check the offset since the multicast address is always at least
// 128-bit aligned
const size_t ptr_alignment = get_alignment(
static_cast<size_t>(input.storage_offset() * input.element_size()));
TORCH_CHECK(
ptr_alignment >= min_alignment,
op_name,
"<",
input.scalar_type(),
">: input ptr + offset must be at least ",
min_alignment,
"-byte aligned.");
const size_t size_alignment =
get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
TORCH_CHECK(
size_alignment >= min_alignment,
op_name,
"<",
input.scalar_type(),
">: input size must be at least ",
min_alignment,
"-byte aligned.");
return std::min(ptr_alignment, size_alignment);
}
void init_elementwise_launch_config(
size_t numel,
size_t element_size,
size_t alignment,
size_t splits,
int& num_blocks,
int& num_threads) {
// Align to preserve alignment in each split
const size_t aligned_numel = at::round_up(numel, alignment * splits);
const size_t numel_per_split = aligned_numel / splits;
const size_t numel_per_thread = alignment / element_size;
if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
num_blocks = 1;
num_threads = at::round_up(
at::ceil_div(numel_per_split, numel_per_thread),
static_cast<size_t>(C10_WARP_SIZE));
} else {
num_blocks = std::min(
at::ceil_div(
numel_per_split, max_num_threads_per_block * numel_per_thread),
max_num_blocks);
num_threads = max_num_threads_per_block;
}
}
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
T* input_mc_ptr,
size_t numel,
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
static_assert(alignment % sizeof(T) == 0);
constexpr size_t numel_per_thread = alignment / sizeof(T);
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
const size_t numel_per_rank =
at::round_up(numel, alignment * world_size) / world_size;
const size_t start = numel_per_rank * rank;
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
auto stride = blockDim.x * gridDim.x * numel_per_thread;
for (size_t i = offset; i < numel_per_rank; i += stride) {
if (start + i >= numel) {
continue;
}
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
multimem_st<alignment>(input_mc_ptr + start + i, vec);
}
// Establish observation order - all writes are in-flight beyond this point.
barrier(signal_pads, rank, world_size);
// Establish causality order - all writes are visible to all devices beyond
// this point.
__threadfence_system();
}
at::Tensor multimem_all_reduce_(
const at::Tensor& input,
std::string reduce_op,
std::string group_name) {
TORCH_CHECK(
input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
TORCH_CHECK(
reduce_op == "sum",
"multimem_all_reduce_: only sum is supported for now.");
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
TORCH_CHECK(
symm_mem != nullptr,
"multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
TORCH_CHECK(
symm_mem->has_multicast_support(),
"multimem_all_reduce_: multicast support is required.");
const size_t alignment =
get_and_verify_alignment(input, "multimem_all_reduce_");
int num_blocks = 0, num_threads = 0;
init_elementwise_launch_config(
input.numel(),
input.element_size(),
alignment,
symm_mem->get_world_size(),
num_blocks,
num_threads);
#define DISPATCH(scalar_t, kernel_alignment) \
if (alignment == kernel_alignment) { \
multimem_all_reduce_kernel<scalar_t, kernel_alignment> \
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
input.storage_offset(), \
input.numel(), \
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
symm_mem->get_rank(), \
symm_mem->get_world_size()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
AT_DISPATCH_SWITCH(
input.scalar_type(),
"multimem_all_reduce",
AT_DISPATCH_CASE(at::kBFloat16, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}) AT_DISPATCH_CASE(at::kFloat, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}));
#undef DISPATCH
return input;
}
template <typename T, int alignment>
static __global__ void multimem_one_shot_all_reduce_kernel(
T* input_mc_ptr,
T* output_ptr,
size_t numel,
uint32_t** signal_pads,
size_t rank,
size_t world_size) {
static_assert(alignment % sizeof(T) == 0);
constexpr size_t numel_per_thread = alignment / sizeof(T);
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
auto stride = blockDim.x * gridDim.x * numel_per_thread;
for (size_t i = offset; i < numel; i += stride) {
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
*reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
}
}
at::Tensor multimem_one_shot_all_reduce(
const at::Tensor& input,
std::string reduce_op,
std::string group_name) {
TORCH_CHECK(
input.is_contiguous(),
"multimem_one_shot_all_reduce: input must be contiguous.");
TORCH_CHECK(
reduce_op == "sum",
"multimem_one_shot_all_reduce: only sum is supported for now.");
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
TORCH_CHECK(
symm_mem != nullptr,
"multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
TORCH_CHECK(
symm_mem->has_multicast_support(),
"multimem_one_shot_all_reduce: requires multicast support.");
auto output = at::empty_like(input);
const size_t alignment =
get_and_verify_alignment(input, "multimem_one_shot_all_reduce");
int num_blocks = 0, num_threads = 0;
init_elementwise_launch_config(
input.numel(),
input.element_size(),
alignment,
1,
num_blocks,
num_threads);
#define DISPATCH(scalar_t, kernel_alignment) \
if (alignment == kernel_alignment) { \
multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment> \
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
input.storage_offset(), \
output.data_ptr<scalar_t>(), \
input.numel(), \
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
symm_mem->get_rank(), \
symm_mem->get_world_size()); \
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
}
AT_DISPATCH_SWITCH(
input.scalar_type(),
"multimem_all_reduce",
AT_DISPATCH_CASE(at::kBFloat16, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}) AT_DISPATCH_CASE(at::kFloat, [&] {
DISPATCH(scalar_t, 16);
DISPATCH(scalar_t, 8);
DISPATCH(scalar_t, 4);
}));
return output;
}
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
m.def(
"multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
{at::Tag::pt2_compliant_tag});
m.def(
"multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
{at::Tag::pt2_compliant_tag});
}
} // namespace
#endif


@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p(
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
const at::Tensor& tensor) {
auto allocator = get_allocator(tensor.device().type());
return allocator->rendezvous(tensor.data_ptr());
return allocator->rendezvous(tensor.storage().data_ptr().get());
}
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
@ -189,5 +189,9 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
return allocator->rendezvous(tensor.data_ptr());
}
TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
auto allocator = get_allocator(device_type);
return allocator->has_multicast_support();
}
} // namespace symmetric_memory
} // namespace c10d


@ -51,6 +51,9 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
virtual size_t get_buffer_size() = 0;
virtual size_t get_signal_pad_size() = 0;
virtual bool has_multicast_support() = 0;
virtual void* get_multicast_ptr() = 0;
virtual at::Tensor get_buffer(
int rank,
c10::IntArrayRef sizes,
@ -78,6 +81,7 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
virtual size_t get_alloc_size(void* ptr) = 0;
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
virtual bool is_rendezvous_completed(void* ptr) = 0;
virtual bool has_multicast_support() = 0;
};
C10_EXPORT bool is_finalizing();
@ -150,5 +154,6 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
const at::Tensor& tensor);
TORCH_API bool has_multicast_support(c10::DeviceType device_type);
} // namespace symmetric_memory
} // namespace c10d


@ -1044,6 +1044,9 @@ This class does not support ``__members__`` property.)");
.def_static(
"get_symmetric_memory",
&::c10d::symmetric_memory::get_symmetric_memory)
.def_static(
"has_multicast_support",
&::c10d::symmetric_memory::has_multicast_support)
.def_property_readonly("rank", &SymmetricMemory::get_rank)
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
.def_property_readonly(