mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[SymmetricMemory] introduce multicast support, multimem_all_reduce_ and multimem_one_shot_all_reduce (#133424)
### Summary - Added multicast support to SymmetricMemory. If the cuda runtime and cuda driver have multicast support, SymmetricMemory associate all peer buffers with a multicast object and exposes the multicast virtual address. - Implemented `multimem_all_reduce_` and `multimem_one_shot_all_reduce` based on the multicast support. The two variants shows different performance characteristic for different message size. We plan to use Inductor for collective algo selection (and required symmetric memory buffer allocation). ### Benchmark 8xH100 (non-standard version with HBM2e at 650W). NVSwitch V3 with NVLS support.   Differential Revision: [D61682507](https://our.internmc.facebook.com/intern/diff/D61682507) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133424 Approved by: https://github.com/yf225, https://github.com/weifengpy
This commit is contained in:
committed by
PyTorch MergeBot
parent
2ca7f0fc5c
commit
78d69bfe11
@ -744,6 +744,7 @@ cc_library(
|
||||
"torch/csrc/cuda/nccl.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
],
|
||||
|
@ -688,6 +688,7 @@ libtorch_cuda_distributed_extra_sources = [
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cpp",
|
||||
"torch/csrc/distributed/c10d/intra_node_comm.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
|
||||
"torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
|
||||
"torch/csrc/distributed/c10d/Utils.cu",
|
||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||
"torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
|
||||
|
@ -20,6 +20,12 @@ DriverAPI create_driver_api() {
|
||||
C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
|
||||
#undef LOOKUP_LIBCUDA_ENTRY
|
||||
|
||||
#define LOOKUP_LIBCUDA_ENTRY(name) \
|
||||
r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
|
||||
dlerror();
|
||||
C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
|
||||
#undef LOOKUP_LIBCUDA_ENTRY
|
||||
|
||||
if (handle_1) {
|
||||
#define LOOKUP_NVML_ENTRY(name) \
|
||||
r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \
|
||||
|
@ -31,6 +31,15 @@
|
||||
_(cuMemImportFromShareableHandle) \
|
||||
_(cuGetErrorString)
|
||||
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||
#define C10_LIBCUDA_DRIVER_API_12030(_) \
|
||||
_(cuMulticastAddDevice) \
|
||||
_(cuMulticastBindMem) \
|
||||
_(cuMulticastCreate)
|
||||
#else
|
||||
#define C10_LIBCUDA_DRIVER_API_12030(_)
|
||||
#endif
|
||||
|
||||
#define C10_NVML_DRIVER_API(_) \
|
||||
_(nvmlInit_v2) \
|
||||
_(nvmlDeviceGetHandleByPciBusId_v2) \
|
||||
@ -43,6 +52,7 @@ namespace c10::cuda {
|
||||
struct DriverAPI {
|
||||
#define CREATE_MEMBER(name) decltype(&name) name##_;
|
||||
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
|
||||
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
|
||||
C10_NVML_DRIVER_API(CREATE_MEMBER)
|
||||
#undef CREATE_MEMBER
|
||||
static DriverAPI* get();
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _SymmetricMemory
|
||||
from torch.distributed._symmetric_memory import (
|
||||
_fused_all_gather_matmul_fallback,
|
||||
@ -44,6 +45,17 @@ def requires_cuda_p2p_access():
|
||||
)
|
||||
|
||||
|
||||
def requires_multicast_support():
|
||||
has_multicast_support = (
|
||||
torch.cuda.is_available()
|
||||
and _SymmetricMemory.has_multicast_support(DeviceType.CUDA)
|
||||
)
|
||||
return skip_but_pass_in_sandcastle_if(
|
||||
not has_multicast_support,
|
||||
"multicast support is not available",
|
||||
)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
@requires_cuda_p2p_access()
|
||||
class SymmetricMemoryTest(MultiProcessTestCase):
|
||||
@ -95,7 +107,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||
@skipIfRocm
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cuda_nvlink_connectivity_detection(self) -> None:
|
||||
from torch._C._autograd import DeviceType
|
||||
from torch._C._distributed_c10d import _detect_dma_connectivity
|
||||
|
||||
connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
|
||||
@ -422,6 +433,73 @@ class SymmetricMemoryTest(MultiProcessTestCase):
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_multicast_support()
|
||||
@parametrize("dtype", [torch.float, torch.bfloat16])
|
||||
@parametrize("align_bytes", [4, 8, 16])
|
||||
@parametrize("size_bytes", [4, 8192, 8196])
|
||||
def test_multimem_all_reduce(
|
||||
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
|
||||
) -> None:
|
||||
self._init_process()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(16384,),
|
||||
stride=(1,),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
group_name=group_name,
|
||||
).fill_(1)
|
||||
|
||||
self.assertTrue(t.data_ptr() % 16 == 0)
|
||||
self.assertTrue(align_bytes % t.element_size() == 0)
|
||||
self.assertTrue(size_bytes % t.element_size() == 0)
|
||||
|
||||
shift = align_bytes // t.element_size()
|
||||
numel = size_bytes // t.element_size()
|
||||
x = t[shift : shift + numel]
|
||||
|
||||
torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name)
|
||||
self.assertTrue(x.eq(self.world_size).all().item())
|
||||
|
||||
# Head and tail should not be written
|
||||
self.assertTrue(t[:shift].eq(1).all().item())
|
||||
self.assertTrue(t[shift + numel :].eq(1).all().item())
|
||||
dist.destroy_process_group()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_multicast_support()
|
||||
@parametrize("dtype", [torch.float, torch.bfloat16])
|
||||
@parametrize("align_bytes", [4, 8, 16])
|
||||
@parametrize("size_bytes", [4, 8192, 8196])
|
||||
def test_multimem_one_shot_all_reduce(
|
||||
self, dtype: torch.dtype, size_bytes: int, align_bytes: int
|
||||
) -> None:
|
||||
self._init_process()
|
||||
group_name = dist.group.WORLD.group_name
|
||||
|
||||
t = _SymmetricMemory.empty_strided_p2p(
|
||||
size=(16384,),
|
||||
stride=(1,),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
group_name=group_name,
|
||||
).fill_(0)
|
||||
|
||||
self.assertTrue(t.data_ptr() % 16 == 0)
|
||||
self.assertTrue(align_bytes % t.element_size() == 0)
|
||||
self.assertTrue(size_bytes % t.element_size() == 0)
|
||||
|
||||
shift = align_bytes // t.element_size()
|
||||
numel = size_bytes // t.element_size()
|
||||
x = t[shift : shift + numel]
|
||||
x.fill_(1)
|
||||
|
||||
res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name)
|
||||
self.assertTrue(res.eq(self.world_size).all().item())
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
256
torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h
Normal file
256
torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h
Normal file
@ -0,0 +1,256 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
|
||||
#define NVCC_SUPPORTS_MULTICAST 1
|
||||
#endif
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace c10d::symmetric_memory {
|
||||
|
||||
constexpr size_t max_num_threads_per_block = 1024;
|
||||
constexpr size_t max_num_blocks = 8;
|
||||
|
||||
template <typename T>
|
||||
size_t get_alignment(T ptr_or_size) {
|
||||
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
|
||||
if (val % 16 == 0) {
|
||||
return 16;
|
||||
} else if (val % 8 == 0) {
|
||||
return 8;
|
||||
} else if (val % 4 == 0) {
|
||||
return 4;
|
||||
} else if (val % 2 == 0) {
|
||||
return 2;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
size_t get_alignment<size_t>(size_t size) {
|
||||
return get_alignment(reinterpret_cast<void*>(size));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t
|
||||
cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t old_val;
|
||||
asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;"
|
||||
: "=r"(old_val)
|
||||
: "l"(addr), "r"(compare), "r"(val)
|
||||
: "memory");
|
||||
return old_val;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t
|
||||
cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t old_val;
|
||||
asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;"
|
||||
: "=r"(old_val)
|
||||
: "l"(addr), "r"(compare), "r"(val)
|
||||
: "memory");
|
||||
return old_val;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void release_signal(uint32_t* addr) {
|
||||
while (cas_release_sys(addr, 0, 1) != 0)
|
||||
;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
|
||||
while (cas_sys(addr, 1, 0) != 1)
|
||||
;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) {
|
||||
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
uint32_t val;
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||
: "=r"(val)
|
||||
: "l"(addr)
|
||||
: "memory");
|
||||
return val;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Perform a barrier to establish observation order between memory operations
|
||||
// issued before and after the barrier.
|
||||
__device__ __forceinline__ void barrier(
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
if (threadIdx.x < world_size) {
|
||||
auto target_rank = threadIdx.x;
|
||||
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
|
||||
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Perform a barrier and establish causality order between memory operations
|
||||
// issued before the calling kernel on all devices and memory operations
|
||||
// issued after this function by all thread in the calling kernel.
|
||||
//
|
||||
// NOTE: this function does NOT ensure that memory operations issues in the
|
||||
// current kernel are visible to all threads in the current kernel.
|
||||
//
|
||||
// | mem ops (guaranteed to be visible by all threads at point T)
|
||||
// | kernel K
|
||||
// | +- mem ops (not guaranteed to be visible all threads at point T)
|
||||
// | +- barrier_and_acquire_previous_kernel_writes()
|
||||
// | +- point T
|
||||
// v
|
||||
__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes(
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
if (threadIdx.x < world_size) {
|
||||
auto target_rank = threadIdx.x;
|
||||
release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank);
|
||||
wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
__syncthreads();
|
||||
// At this point, we established observation order between memory operations
|
||||
// issued before and after the barrier. Now we convert the observation order
|
||||
// into causality order by having every thread acquire the signals released
|
||||
// by threads on peer devices. Due to the implicit synchronizes-with
|
||||
// relationships at task/kernel boundaries, acquiring the signal released by
|
||||
// thread T in kernel K transitively acquires memory operations issued prior
|
||||
// to kernel K.
|
||||
//
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference
|
||||
for (size_t target_rank = 0; target_rank < world_size; ++target_rank) {
|
||||
acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Value, class... Args>
|
||||
inline constexpr bool dependent_bool_value = Value;
|
||||
|
||||
template <class... Args>
|
||||
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
|
||||
|
||||
template <int Size>
|
||||
union Vec;
|
||||
|
||||
template <>
|
||||
union Vec<4> {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32, as_scalar;
|
||||
};
|
||||
|
||||
template <>
|
||||
union Vec<8> {
|
||||
uint16_t u16[4];
|
||||
uint32_t u32[2];
|
||||
uint64_t u64, as_scalar;
|
||||
};
|
||||
|
||||
template <>
|
||||
union alignas(16) Vec<16> {
|
||||
uint16_t u16[8];
|
||||
uint32_t u32[4];
|
||||
uint64_t u64[2];
|
||||
uint4 u128, as_scalar;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MultimemLdReduce {
|
||||
template <int Alignment>
|
||||
__device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
|
||||
static_assert(dependent_false<T>);
|
||||
}
|
||||
};
|
||||
|
||||
template <int Alignment, typename T>
|
||||
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
|
||||
MultimemLdReduce<T> functor;
|
||||
return functor.template operator()<Alignment>(mc_ptr);
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
|
||||
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
|
||||
template <> \
|
||||
struct MultimemLdReduce<type> { \
|
||||
template <int Alignment> \
|
||||
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
|
||||
CUDA_KERNEL_ASSERT(false); \
|
||||
} \
|
||||
};
|
||||
#else
|
||||
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \
|
||||
template <> \
|
||||
struct MultimemLdReduce<type> { \
|
||||
template <int Alignment> \
|
||||
__device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) { \
|
||||
Vec<Alignment> vec; \
|
||||
if constexpr (Alignment == 16) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \
|
||||
" {%0,%1,%2,%3}, [%4];" \
|
||||
: "=r"(vec.u32[0]), \
|
||||
"=r"(vec.u32[1]), \
|
||||
"=r"(vec.u32[2]), \
|
||||
"=r"(vec.u32[3]) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} else if constexpr (Alignment == 8) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \
|
||||
" {%0,%1}, [%2];" \
|
||||
: "=r"(vec.u32[0]), "=r"(vec.u32[1]) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} else if constexpr (Alignment == 4) { \
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \
|
||||
: "=r"(vec.u32) \
|
||||
: "l"(mc_ptr) \
|
||||
: "memory"); \
|
||||
} \
|
||||
return vec; \
|
||||
} \
|
||||
};
|
||||
#endif
|
||||
|
||||
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2");
|
||||
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32");
|
||||
|
||||
template <int Alignment, typename T>
|
||||
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
|
||||
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
|
||||
CUDA_KERNEL_ASSERT(false);
|
||||
#else
|
||||
if constexpr (Alignment == 16) {
|
||||
asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
|
||||
:
|
||||
: "l"(mc_ptr),
|
||||
"r"(vec.u32[0]),
|
||||
"r"(vec.u32[1]),
|
||||
"r"(vec.u32[2]),
|
||||
"r"(vec.u32[3])
|
||||
: "memory");
|
||||
} else if constexpr (Alignment == 8) {
|
||||
asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
|
||||
:
|
||||
: "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
|
||||
: "memory");
|
||||
} else if constexpr (Alignment == 4) {
|
||||
asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
|
||||
:
|
||||
: "l"(mc_ptr), "r"(vec.u32)
|
||||
: "memory");
|
||||
} else {
|
||||
static_assert(dependent_false<T>);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace c10d::symmetric_memory
|
@ -14,8 +14,20 @@
|
||||
#include <sys/un.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
|
||||
#define CUDART_SUPPORTS_MULTICAST
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
bool has_multicast_support() {
|
||||
#if defined(CUDART_SUPPORTS_MULTICAST)
|
||||
return c10::cuda::DriverAPI::get()->cuMulticastCreate_ != nullptr;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
class IpcChannel {
|
||||
public:
|
||||
IpcChannel() : socket_name_(get_socket_name(getpid())) {
|
||||
@ -61,9 +73,7 @@ class IpcChannel {
|
||||
memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
|
||||
|
||||
TORCH_CHECK(
|
||||
sendmsg(socket_, &msg, 0) > 0,
|
||||
"Failed to send fd: ",
|
||||
strerror(errno));
|
||||
sendmsg(socket_, &msg, 0) > 0, "Failed to send fd: ", strerror(errno));
|
||||
}
|
||||
|
||||
int recv_fd() {
|
||||
@ -110,6 +120,25 @@ class IpcChannel {
|
||||
return fds;
|
||||
}
|
||||
|
||||
int broadcast_fds(
|
||||
int rank,
|
||||
int src_rank,
|
||||
const std::vector<int>& pids,
|
||||
int fd) {
|
||||
size_t world_size = pids.size();
|
||||
|
||||
if (rank == src_rank) {
|
||||
for (int dst_rank = 0; dst_rank < (int)world_size; ++dst_rank) {
|
||||
if (dst_rank == rank) {
|
||||
continue;
|
||||
}
|
||||
send_fd(pids[dst_rank], fd);
|
||||
}
|
||||
return fd;
|
||||
}
|
||||
return recv_fd();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string get_socket_name(int pid) {
|
||||
const char* tmp_dir = "/tmp";
|
||||
@ -213,6 +242,8 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
||||
size_t block_size,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
void* mc_addr,
|
||||
size_t buffer_size,
|
||||
int local_device_idx,
|
||||
int rank,
|
||||
@ -221,6 +252,8 @@ CUDASymmetricMemory::CUDASymmetricMemory(
|
||||
block_size_(block_size),
|
||||
buffers_(std::move(buffers)),
|
||||
signal_pads_(std::move(signal_pads)),
|
||||
mc_handle_(mc_handle),
|
||||
mc_addr_(mc_addr),
|
||||
buffer_size_(buffer_size),
|
||||
local_device_idx_(local_device_idx),
|
||||
rank_(rank),
|
||||
@ -285,6 +318,14 @@ size_t CUDASymmetricMemory::get_signal_pad_size() {
|
||||
return signal_pad_size;
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemory::has_multicast_support() {
|
||||
return ::has_multicast_support();
|
||||
}
|
||||
|
||||
void* CUDASymmetricMemory::get_multicast_ptr() {
|
||||
return mc_addr_;
|
||||
}
|
||||
|
||||
at::Tensor CUDASymmetricMemory::get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
@ -601,6 +642,46 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
||||
store_barrier(store, rank, world_size);
|
||||
close(block_fd);
|
||||
|
||||
CUmemGenericAllocationHandle mc_handle{};
|
||||
void* mc_addr = nullptr;
|
||||
#if defined(CUDART_SUPPORTS_MULTICAST)
|
||||
// We have to further check if the driver supports multicast
|
||||
if (has_multicast_support()) {
|
||||
// Rank 0 creates a multicast object and share it with peers
|
||||
if (rank == 0) {
|
||||
CUmulticastObjectProp mc_prop{};
|
||||
mc_prop.numDevices = world_size;
|
||||
mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
|
||||
mc_prop.size = block->block_size;
|
||||
|
||||
CUresult res = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
|
||||
TORCH_CHECK(res == CUDA_SUCCESS);
|
||||
|
||||
int mc_fd;
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
|
||||
&mc_fd, mc_handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
|
||||
ipc_channel.broadcast_fds(rank, 0, pids, mc_fd);
|
||||
// Ref count is incremented as soon as SCM_RIGHTS send happens
|
||||
close(mc_fd);
|
||||
} else {
|
||||
int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
|
||||
&mc_handle,
|
||||
(void*)(uintptr_t)mc_fd,
|
||||
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
|
||||
close(mc_fd);
|
||||
}
|
||||
// All rank adds their physical allocation to the multicast object
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
|
||||
C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
|
||||
mc_handle, 0, block->handle, 0, block->block_size, 0));
|
||||
|
||||
map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
|
||||
store_barrier(store, rank, world_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Initializing CUDASymmetricMemory with an allocation transfers its
|
||||
// ownership to the CUDASymmetricMemory object. So that outstanding
|
||||
// references to the CUDASymmetricMemory object can keep the allocation
|
||||
@ -610,6 +691,8 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
|
||||
block->block_size,
|
||||
std::move(buffers),
|
||||
std::move(signal_pads),
|
||||
mc_handle,
|
||||
mc_addr,
|
||||
block->buffer_size,
|
||||
block->device_idx,
|
||||
group_info.rank,
|
||||
@ -630,6 +713,10 @@ bool CUDASymmetricMemoryAllocator::is_rendezvous_completed(void* ptr) {
|
||||
return block->symm_mem != nullptr;
|
||||
}
|
||||
|
||||
bool CUDASymmetricMemoryAllocator::has_multicast_support() {
|
||||
return ::has_multicast_support();
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Block> CUDASymmetricMemoryAllocator::find_block(void* ptr) {
|
||||
std::shared_lock lock(mutex_);
|
||||
auto it = ptr_to_block_.find(ptr);
|
||||
|
@ -20,6 +20,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
||||
size_t block_size,
|
||||
std::vector<void*> buffers,
|
||||
std::vector<void*> signal_pads,
|
||||
HandleType mc_handle,
|
||||
void* mc_addr,
|
||||
size_t buffer_size,
|
||||
int local_device_idx,
|
||||
int rank,
|
||||
@ -34,6 +36,9 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
||||
size_t get_buffer_size() override;
|
||||
size_t get_signal_pad_size() override;
|
||||
|
||||
bool has_multicast_support() override;
|
||||
void* get_multicast_ptr() override;
|
||||
|
||||
at::Tensor get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
@ -52,6 +57,8 @@ class CUDASymmetricMemory : public SymmetricMemory {
|
||||
size_t block_size_;
|
||||
std::vector<void*> buffers_;
|
||||
std::vector<void*> signal_pads_;
|
||||
HandleType mc_handle_;
|
||||
void* mc_addr_;
|
||||
size_t buffer_size_;
|
||||
int local_device_idx_;
|
||||
int rank_;
|
||||
@ -95,6 +102,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
|
||||
size_t get_alloc_size(void* ptr) override;
|
||||
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
|
||||
bool is_rendezvous_completed(void* ptr) override;
|
||||
bool has_multicast_support() override;
|
||||
|
||||
private:
|
||||
c10::intrusive_ptr<Block> find_block(void* ptr);
|
||||
|
267
torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
Normal file
267
torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
Normal file
@ -0,0 +1,267 @@
|
||||
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#endif
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
|
||||
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace c10d::symmetric_memory;
|
||||
|
||||
size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
|
||||
const size_t min_alignment = std::max(4l, input.element_size());
|
||||
// Only check the offset since the multicast address is always at least
|
||||
// 128-bit aligned
|
||||
const size_t ptr_alignment = get_alignment(
|
||||
static_cast<size_t>(input.storage_offset() * input.element_size()));
|
||||
TORCH_CHECK(
|
||||
ptr_alignment >= min_alignment,
|
||||
op_name,
|
||||
"<",
|
||||
input.scalar_type(),
|
||||
">: input ptr + offset must be at least ",
|
||||
min_alignment,
|
||||
"-byte aligned.");
|
||||
|
||||
const size_t size_alignment =
|
||||
get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
|
||||
TORCH_CHECK(
|
||||
size_alignment >= min_alignment,
|
||||
op_name,
|
||||
"<",
|
||||
input.scalar_type(),
|
||||
">: input size must be at least ",
|
||||
min_alignment,
|
||||
"-byte aligned.");
|
||||
return std::min(ptr_alignment, size_alignment);
|
||||
}
|
||||
|
||||
void init_elementwise_launch_config(
|
||||
size_t numel,
|
||||
size_t element_size,
|
||||
size_t alignment,
|
||||
size_t splits,
|
||||
int& num_blocks,
|
||||
int& num_threads) {
|
||||
// Align to preserve alignment in each split
|
||||
const size_t aligned_numel = at::round_up(numel, alignment * splits);
|
||||
const size_t numel_per_split = aligned_numel / splits;
|
||||
const size_t numel_per_thread = alignment / element_size;
|
||||
|
||||
if (numel_per_split <= max_num_threads_per_block * numel_per_thread) {
|
||||
num_blocks = 1;
|
||||
num_threads = at::round_up(
|
||||
at::ceil_div(numel_per_split, numel_per_thread),
|
||||
static_cast<size_t>(C10_WARP_SIZE));
|
||||
} else {
|
||||
num_blocks = std::min(
|
||||
at::ceil_div(
|
||||
numel_per_split, max_num_threads_per_block * numel_per_thread),
|
||||
max_num_blocks);
|
||||
num_threads = max_num_threads_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int alignment>
|
||||
static __global__ void multimem_all_reduce_kernel(
|
||||
T* input_mc_ptr,
|
||||
size_t numel,
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
static_assert(alignment % sizeof(T) == 0);
|
||||
constexpr size_t numel_per_thread = alignment / sizeof(T);
|
||||
|
||||
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
|
||||
|
||||
const size_t numel_per_rank =
|
||||
at::round_up(numel, alignment * world_size) / world_size;
|
||||
const size_t start = numel_per_rank * rank;
|
||||
|
||||
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
|
||||
auto stride = blockDim.x * gridDim.x * numel_per_thread;
|
||||
for (size_t i = offset; i < numel_per_rank; i += stride) {
|
||||
if (start + i >= numel) {
|
||||
continue;
|
||||
}
|
||||
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + start + i);
|
||||
multimem_st<alignment>(input_mc_ptr + start + i, vec);
|
||||
}
|
||||
// Establish observation order - all writes are in-flight beyond this point.
|
||||
barrier(signal_pads, rank, world_size);
|
||||
// Establish causality order - all writes are visible to all devices beyond
|
||||
// this point.
|
||||
__threadfence_system();
|
||||
}
|
||||
|
||||
at::Tensor multimem_all_reduce_(
|
||||
const at::Tensor& input,
|
||||
std::string reduce_op,
|
||||
std::string group_name) {
|
||||
TORCH_CHECK(
|
||||
input.is_contiguous(), "multimem_all_reduce_: input must be contiguous.");
|
||||
TORCH_CHECK(
|
||||
reduce_op == "sum",
|
||||
"multimem_all_reduce_: only sum is supported for now.");
|
||||
|
||||
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
|
||||
TORCH_CHECK(
|
||||
symm_mem != nullptr,
|
||||
"multimem_all_reduce_: input must be allocated with empty_strided_p2p().");
|
||||
TORCH_CHECK(
|
||||
symm_mem->has_multicast_support(),
|
||||
"multimem_all_reduce_: multicast support is required.");
|
||||
|
||||
const size_t alignment =
|
||||
get_and_verify_alignment(input, "multimem_all_reduce_");
|
||||
|
||||
int num_blocks = 0, num_threads = 0;
|
||||
init_elementwise_launch_config(
|
||||
input.numel(),
|
||||
input.element_size(),
|
||||
alignment,
|
||||
symm_mem->get_world_size(),
|
||||
num_blocks,
|
||||
num_threads);
|
||||
|
||||
#define DISPATCH(scalar_t, kernel_alignment) \
|
||||
if (alignment == kernel_alignment) { \
|
||||
multimem_all_reduce_kernel<scalar_t, kernel_alignment> \
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
|
||||
input.storage_offset(), \
|
||||
input.numel(), \
|
||||
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
|
||||
symm_mem->get_rank(), \
|
||||
symm_mem->get_world_size()); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
|
||||
AT_DISPATCH_SWITCH(
|
||||
input.scalar_type(),
|
||||
"multimem_all_reduce",
|
||||
AT_DISPATCH_CASE(at::kBFloat16, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}) AT_DISPATCH_CASE(at::kFloat, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}));
|
||||
|
||||
#undef DISPATCH
|
||||
return input;
|
||||
}
|
||||
|
||||
template <typename T, int alignment>
|
||||
static __global__ void multimem_one_shot_all_reduce_kernel(
|
||||
T* input_mc_ptr,
|
||||
T* output_ptr,
|
||||
size_t numel,
|
||||
uint32_t** signal_pads,
|
||||
size_t rank,
|
||||
size_t world_size) {
|
||||
static_assert(alignment % sizeof(T) == 0);
|
||||
constexpr size_t numel_per_thread = alignment / sizeof(T);
|
||||
|
||||
barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size);
|
||||
|
||||
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
|
||||
auto stride = blockDim.x * gridDim.x * numel_per_thread;
|
||||
for (size_t i = offset; i < numel; i += stride) {
|
||||
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
|
||||
*reinterpret_cast<decltype(vec.as_scalar)*>(output_ptr + i) = vec.as_scalar;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor multimem_one_shot_all_reduce(
|
||||
const at::Tensor& input,
|
||||
std::string reduce_op,
|
||||
std::string group_name) {
|
||||
TORCH_CHECK(
|
||||
input.is_contiguous(),
|
||||
"multimem_one_shot_all_reduce: input must be contiguous.");
|
||||
TORCH_CHECK(
|
||||
reduce_op == "sum",
|
||||
"multimem_one_shot_all_reduce: only sum is supported for now.");
|
||||
|
||||
auto symm_mem = c10d::symmetric_memory::rendezvous(input);
|
||||
TORCH_CHECK(
|
||||
symm_mem != nullptr,
|
||||
"multimem_one_shot_all_reduce: input must be allocated with empty_strided_p2p().");
|
||||
TORCH_CHECK(
|
||||
symm_mem->has_multicast_support(),
|
||||
"multimem_one_shot_all_reduce: requires multicast support.");
|
||||
|
||||
auto output = at::empty_like(input);
|
||||
|
||||
const size_t alignment =
|
||||
get_and_verify_alignment(input, "multimem_one_shot_all_reduce");
|
||||
|
||||
int num_blocks = 0, num_threads = 0;
|
||||
init_elementwise_launch_config(
|
||||
input.numel(),
|
||||
input.element_size(),
|
||||
alignment,
|
||||
1,
|
||||
num_blocks,
|
||||
num_threads);
|
||||
|
||||
#define DISPATCH(scalar_t, kernel_alignment) \
|
||||
if (alignment == kernel_alignment) { \
|
||||
multimem_one_shot_all_reduce_kernel<scalar_t, kernel_alignment> \
|
||||
<<<num_blocks, num_threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
|
||||
reinterpret_cast<scalar_t*>(symm_mem->get_multicast_ptr()) + \
|
||||
input.storage_offset(), \
|
||||
output.data_ptr<scalar_t>(), \
|
||||
input.numel(), \
|
||||
reinterpret_cast<uint32_t**>(symm_mem->get_signal_pad_ptrs_dev()), \
|
||||
symm_mem->get_rank(), \
|
||||
symm_mem->get_world_size()); \
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK(); \
|
||||
}
|
||||
|
||||
AT_DISPATCH_SWITCH(
|
||||
input.scalar_type(),
|
||||
"multimem_all_reduce",
|
||||
AT_DISPATCH_CASE(at::kBFloat16, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}) AT_DISPATCH_CASE(at::kFloat, [&] {
|
||||
DISPATCH(scalar_t, 16);
|
||||
DISPATCH(scalar_t, 8);
|
||||
DISPATCH(scalar_t, 4);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
|
||||
m.def(
|
||||
"multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor",
|
||||
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
|
||||
{at::Tag::pt2_compliant_tag});
|
||||
|
||||
m.def(
|
||||
"multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
|
||||
torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce),
|
||||
{at::Tag::pt2_compliant_tag});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif
|
@ -176,7 +176,7 @@ at::Tensor empty_strided_p2p(
|
||||
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
const at::Tensor& tensor) {
|
||||
auto allocator = get_allocator(tensor.device().type());
|
||||
return allocator->rendezvous(tensor.data_ptr());
|
||||
return allocator->rendezvous(tensor.storage().data_ptr().get());
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
@ -189,5 +189,9 @@ c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
return allocator->rendezvous(tensor.data_ptr());
|
||||
}
|
||||
|
||||
TORCH_API bool has_multicast_support(c10::DeviceType device_type) {
|
||||
auto allocator = get_allocator(device_type);
|
||||
return allocator->has_multicast_support();
|
||||
}
|
||||
} // namespace symmetric_memory
|
||||
} // namespace c10d
|
||||
|
@ -51,6 +51,9 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
|
||||
virtual size_t get_buffer_size() = 0;
|
||||
virtual size_t get_signal_pad_size() = 0;
|
||||
|
||||
virtual bool has_multicast_support() = 0;
|
||||
virtual void* get_multicast_ptr() = 0;
|
||||
|
||||
virtual at::Tensor get_buffer(
|
||||
int rank,
|
||||
c10::IntArrayRef sizes,
|
||||
@ -78,6 +81,7 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
|
||||
virtual size_t get_alloc_size(void* ptr) = 0;
|
||||
virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
|
||||
virtual bool is_rendezvous_completed(void* ptr) = 0;
|
||||
virtual bool has_multicast_support() = 0;
|
||||
};
|
||||
|
||||
C10_EXPORT bool is_finalizing();
|
||||
@ -150,5 +154,6 @@ TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
|
||||
const at::Tensor& tensor);
|
||||
|
||||
TORCH_API bool has_multicast_support(c10::DeviceType device_type);
|
||||
} // namespace symmetric_memory
|
||||
} // namespace c10d
|
||||
|
@ -1044,6 +1044,9 @@ This class does not support ``__members__`` property.)");
|
||||
.def_static(
|
||||
"get_symmetric_memory",
|
||||
&::c10d::symmetric_memory::get_symmetric_memory)
|
||||
.def_static(
|
||||
"has_multicast_support",
|
||||
&::c10d::symmetric_memory::has_multicast_support)
|
||||
.def_property_readonly("rank", &SymmetricMemory::get_rank)
|
||||
.def_property_readonly("world_size", &SymmetricMemory::get_world_size)
|
||||
.def_property_readonly(
|
||||
|
Reference in New Issue
Block a user