Adds NVSHMEM as a backend for `SymmetricMemory`; the implementation lives in `NVSHMEMSymmetricMemory.cu`. Moves some helper functions from `CUDASymmetricMemory.cu` to `CUDASymmetricMemoryUtils.cpp` so that they can be shared by `NVSHMEMSymmetricMemory`. These are mostly side-band exchange helpers (`store_all_gather`, `IpcChannel`, etc.). Adds `TORCH_SYMMEM` to control which implementation is used for CUDA tensors; currently supported values are `CUDA` (the in-house implementation) and `NVSHMEM`. The NVSHMEM feature is gated by the build-time flag `USE_NVSHMEM=1`, and the `NVSHMEM_HOME` setting is required (TODO). Ported most code from #146593.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151261
Approved by: https://github.com/fegin, https://github.com/fduwjj
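
For illustration only, a minimal sketch of how a `TORCH_SYMMEM`-driven backend switch might look, assuming the value is read from the process environment; the helper below is hypothetical and is not the actual PyTorch registration code:

    // Hypothetical backend selector keyed off TORCH_SYMMEM (sketch only).
    #include <cstdlib>
    #include <string>

    enum class SymmMemBackend { CUDA, NVSHMEM };

    inline SymmMemBackend select_symm_mem_backend() {
      const char* env = std::getenv("TORCH_SYMMEM");
      // Default to the in-house CUDA implementation when unset.
      if (env == nullptr || std::string(env) == "CUDA") {
        return SymmMemBackend::CUDA;
      }
      if (std::string(env) == "NVSHMEM") {
        return SymmMemBackend::NVSHMEM;
      }
      // Unknown value: fall back to the default rather than failing here.
      return SymmMemBackend::CUDA;
    }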
#pragma once

#include <atomic>

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && CUDART_VERSION >= 12010
#define NVCC_SUPPORTS_MULTICAST 1
#endif

#include <ATen/ATen.h>
#if !defined(USE_ROCM)
#include <cuda_bf16.h>
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
#include <cuda/atomic>
#endif
#endif
#include <ATen/native/cuda/MemoryAccess.cuh>

namespace c10d::symmetric_memory {

template <int Size>
using Vec = at::native::memory::Vec<Size>;

template <class... T>
inline constexpr bool dependent_false =
    at::native::memory::dependent_false<T...>;

using at::native::memory::get_alignment;

template <std::memory_order Sem>
__device__ __forceinline__ uint32_t
cas(uint32_t* addr, uint32_t compare, uint32_t val) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
  ::cuda::atomic_ref<uint32_t, ::cuda::thread_scope_system> ref(*addr);
  ref.compare_exchange_strong(compare, val, ::cuda::std::memory_order(Sem));
  return compare;
#else
  CUDA_KERNEL_ASSERT(false);
  return 0;
#endif
}

__device__ __forceinline__ void trap() {
#if defined(USE_ROCM)
  assert(0);
#else
  __trap();
#endif
}

__device__ __forceinline__ size_t global_timer_ns() {
#if defined(USE_ROCM)
  CUDA_KERNEL_ASSERT(false);
  return 0;
#else
  size_t val;
  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(val) : : "memory");
  return val;
#endif
}

constexpr size_t ns_per_ms = 1e6;

template <std::memory_order Sem>
__device__ __forceinline__ bool try_put_signal(
    uint32_t* addr,
    size_t timeout_ms) {
  size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms;
  while (cas<Sem>(addr, 0, 1) != 0) {
    if (timeout_ms != 0 && global_timer_ns() > deadline) {
      return false;
    }
  }
  return true;
}

template <std::memory_order Sem>
__device__ __forceinline__ bool try_wait_signal(
    uint32_t* addr,
    size_t timeout_ms) {
  size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms;
  while (cas<Sem>(addr, 1, 0) != 1) {
    if (timeout_ms != 0 && global_timer_ns() > deadline) {
      return false;
    }
  }
  return true;
}

template <std::memory_order Sem>
__device__ __forceinline__ void put_signal(uint32_t* addr) {
  while (cas<Sem>(addr, 0, 1) != 0)
    ;
}

template <std::memory_order Sem>
__device__ __forceinline__ void wait_signal(uint32_t* addr) {
  while (cas<Sem>(addr, 1, 0) != 1)
    ;
}

// Synchronizes blocks with matching blockIdx across participating devices.
// Note: sync_remote_blocks itself is not a system level barrier/fence. It is a
// building block for expressing different synchronization patterns.
//
// Pattern 0: Ensures that all writes to symm_mem buffers from previous
// kernels across all devices are visible to the current kernel:
//
//   sync_remote_blocks<std::memory_order_relaxed>(...);
//   __syncthreads();
//
// Pattern 1: Ensures that all writes to symm_mem buffers from the current
// block are visible to all remote blocks with matching blockIdx:
//
//   __syncthreads();
//   sync_remote_blocks<std::memory_order_acq_rel>(...);
//   __syncthreads();
//
// Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
// for writing by subsequent kernels across all devices.
//
//   __syncthreads();
//   sync_remote_blocks<std::memory_order_relaxed>(...);
template <std::memory_order Sem>
__device__ __forceinline__ void sync_remote_blocks(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size);

template <>
__device__ __forceinline__ void sync_remote_blocks<std::memory_order_relaxed>(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    put_signal<std::memory_order_relaxed>(
        signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal<std::memory_order_relaxed>(
        signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
}

template <>
__device__ __forceinline__ void sync_remote_blocks<std::memory_order_acq_rel>(
    uint32_t** signal_pads,
    size_t rank,
    size_t world_size) {
  if (threadIdx.x < world_size) {
    auto target_rank = threadIdx.x;
    put_signal<std::memory_order_release>(
        signal_pads[target_rank] + blockIdx.x * world_size + rank);
    wait_signal<std::memory_order_acquire>(
        signal_pads[rank] + blockIdx.x * world_size + target_rank);
  }
}
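
// Illustrative example (hypothetical, not part of this header): a minimal
// kernel using Pattern 1 above. Each block writes its slice of this rank's
// symmetric buffer, publishes the writes with an acq_rel sync, then reads the
// matching slices written by its peers. `buffers` is assumed to hold every
// rank's pointer to the same symmetric allocation, `out` is a regular local
// buffer, and the launch configuration is assumed identical on all ranks.
static __global__ void example_pattern1_kernel(
    float** buffers,
    uint32_t** signal_pads,
    float* out,
    size_t rank,
    size_t world_size) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Local write into this rank's symmetric buffer.
  buffers[rank][idx] = static_cast<float>(rank);

  // Pattern 1: make the write above visible to remote blocks with matching
  // blockIdx before any of them read it (and vice versa).
  __syncthreads();
  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
  __syncthreads();

  // Matching elements written by peer ranks are now safe to read.
  float sum = 0.0f;
  for (size_t r = 0; r < world_size; ++r) {
    sum += buffers[r][idx];
  }
  out[idx] = sum;
}
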
template <typename T>
struct MultimemLdReduce {
  template <int Alignment>
  __device__ __inline__ Vec<Alignment> operator()(T* mc_ptr) {
    static_assert(dependent_false<T>);
  }
};

template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> multimem_ld_reduce_add(T* mc_ptr) {
  MultimemLdReduce<T> functor;
  return functor.template operator()<Alignment>(mc_ptr);
}

#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type, acc_prec) \
  template <>                                                          \
  struct MultimemLdReduce<type> {                                      \
    template <int Alignment>                                           \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) {    \
      CUDA_KERNEL_ASSERT(false);                                       \
    }                                                                  \
  };
#else
#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type, acc_prec)    \
  template <>                                                             \
  struct MultimemLdReduce<type> {                                         \
    template <int Alignment>                                              \
    __device__ __inline__ Vec<Alignment> operator()(type* mc_ptr) {       \
      Vec<Alignment> vec;                                                 \
      if constexpr (Alignment == 16) {                                    \
        asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec          \
            ".v4" asm_type " {%0,%1,%2,%3}, [%4];"                        \
            : "=r"(vec.u32[0]),                                           \
              "=r"(vec.u32[1]),                                           \
              "=r"(vec.u32[2]),                                           \
              "=r"(vec.u32[3])                                            \
            : "l"(mc_ptr)                                                 \
            : "memory");                                                  \
      } else if constexpr (Alignment == 8) {                              \
        asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec          \
            ".v2" asm_type " {%0,%1}, [%2];"                              \
            : "=r"(vec.u32[0]), "=r"(vec.u32[1])                          \
            : "l"(mc_ptr)                                                 \
            : "memory");                                                  \
      } else if constexpr (Alignment == 4) {                              \
        asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec asm_type \
            " %0, [%1];"                                                  \
            : "=r"(vec.u32)                                               \
            : "l"(mc_ptr)                                                 \
            : "memory");                                                  \
      }                                                                   \
      return vec;                                                         \
    }                                                                     \
  };
#endif

SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, ".bf16x2", ".acc::f32");
SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, ".f32", "");

template <int Alignment, typename T>
__device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST)
  CUDA_KERNEL_ASSERT(false);
#else
  if constexpr (Alignment == 16) {
    asm("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};"
        :
        : "l"(mc_ptr),
          "r"(vec.u32[0]),
          "r"(vec.u32[1]),
          "r"(vec.u32[2]),
          "r"(vec.u32[3])
        : "memory");
  } else if constexpr (Alignment == 8) {
    asm("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};"
        :
        : "l"(mc_ptr), "r"(vec.u32[0]), "r"(vec.u32[1])
        : "memory");
  } else if constexpr (Alignment == 4) {
    asm("multimem.st.relaxed.sys.global.f32 [%0], %1;"
        :
        : "l"(mc_ptr), "r"(vec.u32)
        : "memory");
  } else {
    static_assert(dependent_false<T>);
  }
#endif
}
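
// Illustrative sketch (hypothetical, not part of this header): one step of a
// multicast-based all-reduce on a float buffer. `mc_ptr` is assumed to be a
// multicast address bound to every rank's copy of the buffer. Each rank
// reduces and re-broadcasts only its own contiguous shard, so no two ranks
// touch the same chunk; the surrounding kernel is expected to synchronize
// ranks before and after this step (e.g., with sync_remote_blocks).
__device__ __forceinline__ void example_multimem_all_reduce_f32(
    float* mc_ptr,
    size_t numel,
    size_t rank,
    size_t world_size) {
  constexpr int kAlignment = 16;
  constexpr size_t kElemsPerVec = kAlignment / sizeof(float); // 4 floats
  // Shard owned by this rank (numel assumed divisible by world_size for brevity).
  size_t shard_numel = numel / world_size;
  size_t shard_begin = rank * shard_numel;
  for (size_t i = (blockIdx.x * blockDim.x + threadIdx.x) * kElemsPerVec;
       i < shard_numel;
       i += gridDim.x * blockDim.x * kElemsPerVec) {
    // Load the same 16-byte chunk from all ranks' copies, reduced by the
    // multicast load.
    Vec<kAlignment> acc =
        multimem_ld_reduce_add<kAlignment>(mc_ptr + shard_begin + i);
    // Broadcast the reduced chunk back to every rank's copy.
    multimem_st<kAlignment>(mc_ptr + shard_begin + i, acc);
  }
}
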
#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif

template <typename T>
__device__ __inline__ T add_bf16x2(T a, T b) {
  static_assert(sizeof(T) == 4);
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  return T{};
#else
  auto res = __hadd2(
      *reinterpret_cast<__nv_bfloat162*>(&a),
      *reinterpret_cast<__nv_bfloat162*>(&b));
  return *reinterpret_cast<T*>(&res);
#endif
}

template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> add_vec(
    const Vec<Alignment>& a,
    const Vec<Alignment>& b) {
  Vec<Alignment> c{};
  if constexpr (std::is_same_v<T, float>) {
    if constexpr (Alignment == 16) {
      c.f32[0] = a.f32[0] + b.f32[0];
      c.f32[1] = a.f32[1] + b.f32[1];
      c.f32[2] = a.f32[2] + b.f32[2];
      c.f32[3] = a.f32[3] + b.f32[3];
    } else if constexpr (Alignment == 8) {
      c.f32[0] = a.f32[0] + b.f32[0];
      c.f32[1] = a.f32[1] + b.f32[1];
    } else if constexpr (Alignment == 4) {
      c.f32 = a.f32 + b.f32;
    } else {
      static_assert(dependent_false<T>);
    }
  } else if constexpr (std::is_same_v<T, at::BFloat16>) {
    if constexpr (Alignment == 16) {
      c.u32[0] = add_bf16x2(a.u32[0], b.u32[0]);
      c.u32[1] = add_bf16x2(a.u32[1], b.u32[1]);
      c.u32[2] = add_bf16x2(a.u32[2], b.u32[2]);
      c.u32[3] = add_bf16x2(a.u32[3], b.u32[3]);
    } else if constexpr (Alignment == 8) {
      c.u32[0] = add_bf16x2(a.u32[0], b.u32[0]);
      c.u32[1] = add_bf16x2(a.u32[1], b.u32[1]);
    } else if constexpr (Alignment == 4) {
      c.u32 = add_bf16x2(a.u32, b.u32);
    } else {
      static_assert(dependent_false<T>);
    }
  } else {
    static_assert(dependent_false<T>);
  }
  return c;
}

// With world_size specialization: perform balanced load from all peers before
// performing reduction.
template <typename T, int alignment, int k_world_size>
__device__ inline std::enable_if_t<(k_world_size > 0), Vec<alignment>>
load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) {
  Vec<alignment> vecs[k_world_size];
#pragma unroll k_world_size
  for (size_t step = 0; step < k_world_size; ++step) {
    size_t remote_rank = (rank + step) % k_world_size;
    vecs[remote_rank] =
        at::native::memory::ld_vec<alignment>(ptrs[remote_rank] + offset);
  }
  auto acc = vecs[0];
#pragma unroll k_world_size - 1
  for (size_t r = 1; r < world_size; ++r) {
    acc = add_vec<alignment, T>(acc, vecs[r]);
  }
  return acc;
}

// Without world_size specialization: perform ordered (unbalanced) load and
// accumulate on each load.
template <typename T, int alignment, int k_world_size>
__device__ inline std::enable_if_t<(k_world_size <= 0), Vec<alignment>>
load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) {
  Vec<alignment> acc{};
  for (size_t step = 0; step < world_size; ++step) {
    auto vec = at::native::memory::ld_vec<alignment>(ptrs[step] + offset);
    acc = add_vec<alignment, T>(acc, vec);
  }
  return acc;
}
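
// Illustrative example (hypothetical, not part of this header): a pull-style
// reduction that gathers the same 16-byte chunk from every peer's symmetric
// buffer with load_and_reduce and writes the sum to a local, 16-byte-aligned
// output. `ptrs` is assumed to hold each rank's pointer to the same symmetric
// allocation, kWorldSize is a compile-time world size (or -1 for the generic
// path), and peers' data is assumed to have been made visible beforehand
// (e.g., via sync_remote_blocks).
template <int kWorldSize>
__global__ void example_pull_reduce_kernel(
    float** ptrs,
    float* out,
    size_t rank,
    size_t world_size,
    size_t numel) {
  constexpr int kAlignment = 16;
  constexpr size_t kElemsPerVec = kAlignment / sizeof(float); // 4 floats
  for (size_t i = (blockIdx.x * blockDim.x + threadIdx.x) * kElemsPerVec;
       i < numel;
       i += gridDim.x * blockDim.x * kElemsPerVec) {
    // Reduce this 16-byte chunk across all ranks' buffers.
    Vec<kAlignment> acc =
        load_and_reduce<float, kAlignment, kWorldSize>(ptrs, rank, world_size, i);
    // Store the reduced vector to the local output.
    *reinterpret_cast<Vec<kAlignment>*>(out + i) = acc;
  }
}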

} // namespace c10d::symmetric_memory