Reland fast gather and index implementation (#151917)

This PR reapplies #151490 and #151753 together, and adds some missing checks when applying the fast path.
Previously missed checks:
1) indexing path has the stride in the indexed dimension in bytes, gather path has the stride in the indexed dimension in elements. When checking if fast path is applicable, I didn't take this difference into account, and still multiplied the indexing stride by element size. Fixed and test added
2) We want to take fast path only when we are copying contiguous equally spaced slices of inputs + all the necessary alignment requirements. The effective tensor size should be 2d (after all possible flattening is applied), the index stride in the last dimension should be 0, and, since in the kernel we are not applying non-indexing-related offsets to src tensor, the src tensor stride in the second dimension should be 0. This automatically happens for gather with dim=0, so I didn't put in an explicit condition for this. Sometimes all conditions except first dim "effective" stride equal to 0 are satisfied for scatter on non-zero dim, when index size in the indexing dimension is 1 and thus it is collapsed (dimensions of size 1 are always collapsed), e.g.
```
        # test gather along 1st dim that can accidentally trigger fast path
        # because due to index dimension in the gather dim being 1
        # an unexpected squashing in tensorIterator happens
        src = make_tensor((16, 2, 16), device=device, dtype=dtype)
        ind = torch.randint(2, (16, 1), device=device).view(16, 1, 1).expand(16, 1, 16)
        res = torch.gather(src, dim=1, index=ind)
        if res.device.type == "cuda":
            ref_cpu = torch.gather(src.cpu(), dim=1, index=ind.cpu())
            self.assertEqual(res.cpu(), ref_cpu, atol=0, rtol=0)
```
Note that if index size here was (16, 2, 16) instead of (16, 1, 16) then the middle dimension could not be collapsed and we wouldn't end up incorrectly taking fast path.
We could update the kernel to take this stride into account when computing offsets into src tensor, or we could specifically disallow non-zero stride on the first dimension. I took the second path for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151917
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/Skylion007
This commit is contained in:
Natalia Gimelshein
2025-04-23 19:13:13 +00:00
committed by PyTorch MergeBot
parent 69e41cee04
commit 99ae7d4069
8 changed files with 326 additions and 129 deletions

View File

@ -14,6 +14,8 @@
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/quantized/IndexKernel.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/IndexKernelUtils.h>
#include <c10/core/Scalar.h>
@ -52,7 +54,7 @@ static void launch_kernel(const int64_t N, const func_t& f) {
}
template <typename func_t>
void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f) {
void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride, const func_t& f, const bool is_gather_like) {
const auto num_indices = index_size.size();
AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(static_cast<int64_t>(num_indices) == iter.ntensors() - 2);
@ -63,11 +65,31 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_index_kernel(sub_iter, index_size, index_stride, f);
gpu_index_kernel(sub_iter, index_size, index_stride, f, is_gather_like);
}
return;
}
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
if (is_gather_like && num_indices==1) {
const size_t element_size = iter.element_size(0);
constexpr size_t alignment = 16;
if (at::native::fast_gather_kernel_eligible<alignment>(iter, out_ptr, in_ptr, index_stride[0], element_size)) {
auto slice_size = iter.shape()[0] * element_size;
auto num_ind = iter.shape()[1];
auto ind_dim_size = index_size[0];
auto inp_stride_bytes = index_stride[0];
auto out_stride_bytes = iter.strides(0)[1];
if (iter.numel() == 0) return;
at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
return;
}
}
auto sizes = std::array<int64_t, MAX_DIMS>{};
auto strides = std::array<int64_t, MAX_DIMS>{};
auto index_ptrs = std::array<char*, MAX_DIMS>{};
@ -77,8 +99,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
index_ptrs[i] = (char*)iter.data_ptr(i + 2);
}
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
auto offset_calc = make_offset_calculator<3>(iter);
launch_kernel<launch_size_nd, launch_bound2>(iter.numel(), [=]__device__(int idx) {
@ -183,14 +203,14 @@ template <typename scalar_t>
void index_kernel_impl(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
*reinterpret_cast<scalar_t*>(out_data) = *reinterpret_cast<const scalar_t*>(in_data + offset);
});
}, true);
}
template <typename scalar_t>
void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
*reinterpret_cast<scalar_t*>(out_data + offset) = *reinterpret_cast<const scalar_t*>(in_data);
});
}, false);
}
static void index_kernel(
@ -280,7 +300,7 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
// The replacement should generate the same PTX as std::clamp. See https://godbolt.org/z/Wde9KW3v4
qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
*(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
});
}, false);
});
}

View File

@ -0,0 +1,44 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/ceil_div.h>
namespace at::native {
template <int Alignment>
__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
int64_t ind = idx[blockIdx.x];
if (allow_neg_indices) {
ind = (ind < 0) ? ind + ind_dim_size : ind;
}
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
if (off >= slice_size) return;
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
}
template <int64_t Alignment>
void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){
constexpr int64_t max_num_threads=256;
auto num_threads = at::round_up(
at::ceil_div(slice_size_in_bytes, Alignment),
static_cast<int64_t>(C10_WARP_SIZE));
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
auto block = std::min(max_num_threads, num_threads);
vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// explicit template instantiation
template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
}

View File

@ -0,0 +1,35 @@
#include <cstdint>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
namespace at::native {
template<int alignment>
inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const out_ptr, char * const in_ptr, const size_t index_stride_bytes, const size_t element_size) {
using at::native::memory::get_alignment;
const auto index_element_size = iter.element_size(2);
//TensorIterator strides and sizes are ordered fastest moving to slowest moving,
//in contrast to regular sizes
// we need contiguous source and dst slices and aligned pointers and strides and slice size to do vectorized loads
// also we need idx to be expanded in the last dimension so we can copy entire slices
// and we need the src tensor to keep 0 stride from restriding
// (it could have been deleted by dimension collapse, in this case iterator would still be 2d
// but we cannot use fast path)
return iter.ndim() == 2 && iter.strides(2)[0]==0 && iter.strides(2)[1]==index_element_size &&
static_cast<size_t>(iter.strides(0)[0])==element_size &&
static_cast<size_t>(iter.strides(1)[0])==element_size && static_cast<size_t>(iter.strides(1)[1] == 0) &&
get_alignment(out_ptr) == alignment && get_alignment(in_ptr) == alignment &&
get_alignment(static_cast<size_t>(iter.shape()[0] * element_size)) == alignment &&
get_alignment(static_cast<size_t>(index_stride_bytes)) == alignment &&
get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
}
template <int64_t Alignment>
void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
bool allow_neg_indices=false);
}

View File

@ -536,4 +536,123 @@ inline int can_vectorize_up_to(array_t pointers) {
return result;
}
template <typename T>
__inline__ size_t get_alignment(T ptr_or_size) {
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
if (val % 16 == 0) {
return 16;
} else if (val % 8 == 0) {
return 8;
} else if (val % 4 == 0) {
return 4;
} else if (val % 2 == 0) {
return 2;
} else {
return 1;
}
}
template <>
__inline__ size_t get_alignment<size_t>(size_t size) {
return get_alignment(reinterpret_cast<void*>(size));
}
template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;
template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
template <int Size>
union Vec;
template <>
union Vec<4> {
uint16_t u16[2];
uint32_t u32, as_scalar;
float f32;
};
template <>
union Vec<8> {
uint16_t u16[4];
uint32_t u32[2];
uint64_t u64, as_scalar;
float f32[2];
};
template <>
union alignas(16) Vec<16> {
uint16_t u16[8];
uint32_t u32[4];
uint64_t u64[2];
uint4 u128, as_scalar;
float f32[4];
};
template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> ld_vec(const T* addr) {
Vec<Alignment> vec;
if constexpr (Alignment == 16) {
#if defined(USE_ROCM)
vec.u128 = *reinterpret_cast<const uint4*>(addr);
} else if constexpr (Alignment == 8) {
vec.u64 = *reinterpret_cast<const uint64_t*>(addr);
} else if constexpr (Alignment == 4) {
vec.u32 = *reinterpret_cast<const uint32_t*>(addr);
#else
asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3])
: "l"(addr)
: "memory");
} else if constexpr (Alignment == 8) {
asm("ld.global.v2.u32 {%0,%1}, [%2];"
: "=r"(vec.u32[0]), "=r"(vec.u32[1])
: "l"(addr)
: "memory");
} else if constexpr (Alignment == 4) {
asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory");
#endif
} else {
static_assert(dependent_false<T>);
}
return vec;
}
template <int Alignment, typename T>
__device__ __inline__ void st_vec(T* addr, const Vec<Alignment>& vec) {
if constexpr (Alignment == 16) {
#if defined(USE_ROCM)
reinterpret_cast<uint64_t*>(addr)[0] = vec.u64[0];
reinterpret_cast<uint64_t*>(addr)[1] = vec.u64[1];
} else if constexpr (Alignment == 8) {
*reinterpret_cast<uint64_t*>(addr) = vec.u64;
} else if constexpr (Alignment == 4) {
*reinterpret_cast<uint32_t*>(addr) = vec.u32;
#else
asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};"
:
: "l"(addr),
"r"(vec.u32[0]),
"r"(vec.u32[1]),
"r"(vec.u32[2]),
"r"(vec.u32[3])
: "memory");
} else if constexpr (Alignment == 8) {
asm("st.global.v2.u32 [%0], {%1,%2};"
:
: "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1])
: "memory");
} else if constexpr (Alignment == 4) {
asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory");
#endif
} else {
static_assert(dependent_false<T>);
}
}
} // namespace at::native::memory

View File

@ -1,16 +1,16 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ceil_div.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/IndexKernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
@ -116,7 +116,6 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <bool is_scatter_like, typename scalar_t>
struct _cuda_scatter_gather_internal_kernel {
template <typename func_t>
@ -140,13 +139,29 @@ struct _cuda_scatter_gather_internal_kernel {
char* src_ptr = (char*)iter.data_ptr(1);
char* index_ptr = (char*)iter.data_ptr(2);
if constexpr (!is_scatter_like) {
// we can go to faster path if we are indexing on the first dim
// the dst and src are contiguous and all the dims and pts are multiple of 16
constexpr size_t element_size = sizeof(scalar_t);
constexpr size_t alignment = 16;
if (at::native::fast_gather_kernel_eligible<alignment>(iter, self_ptr, src_ptr, index_stride * element_size, element_size)) {
auto slice_size = iter.shape()[0] * element_size;
auto num_ind = iter.shape()[1];
auto ind_dim_size = index_size;
auto inp_stride_bytes = index_stride * element_size;
auto out_stride_bytes = iter.strides(0)[1];
if (iter.numel() == 0) return;
at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
return;
}
}
auto offset_calc = make_offset_calculator<3>(iter);
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds");
&& "scatter gather kernel index out of bounds");
f(
(scalar_t*)(self_ptr + offsets[0]),
@ -157,6 +172,7 @@ struct _cuda_scatter_gather_internal_kernel {
};
_launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
}; // struct _cuda_scatter_gather_internal_kernel

View File

@ -65,6 +65,63 @@ class TestScatterGather(TestCase):
actual = torch.gather(src, 2, idx)
self.assertEqual(actual, expected, atol=0, rtol=0)
@dtypes(torch.int8, torch.bfloat16)
def test_gather_large(self, device, dtype):
# test larger shapes to check vectorized implementation
for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)):
src = make_tensor((m, k), device=device, dtype=dtype)
alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype)
discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src)
alloc1 = torch.empty(src.nelement() + 1, device=device, dtype=dtype)
misaligned = alloc1[1:].view(m, k).copy_(src)
alloc2 = torch.empty(m, k + 4, device=device, dtype=dtype)
misaligned1 = alloc2[:, :-4].copy_(src)
num_ind = n
for dim in (0, 1):
max_ind = src.shape[dim]
ind0 = torch.randint(max_ind, (num_ind,), device=device)
ind_discontig0 = torch.empty(num_ind * 2, device=device, dtype=torch.int64)[::2].copy_(ind0)
shape_ind = [1] * src.ndim
shape_ind[dim] = ind0.shape[0]
shape_out = list(src.shape)
shape_out[dim] = ind0.shape[0]
ind = ind0.view(shape_ind).expand(shape_out)
ind_discontig = ind_discontig0.view(shape_ind).expand(shape_out)
res = torch.gather(src, dim=dim, index=ind)
ref = src[ind0] if dim == 0 else src[:, ind0]
self.assertEqual(res, ref, atol=0, rtol=0)
if res.device.type == "cuda":
ref_cpu = src.cpu()[ind0.cpu()] if dim == 0 else src.cpu()[:, ind0.cpu()]
self.assertEqual(res.cpu(), ref_cpu, atol=0, rtol=0)
res = torch.gather(src, dim=dim, index=ind_discontig)
self.assertEqual(res, ref, atol=0, rtol=0)
res_ind = src[ind_discontig0] if dim == 0 else src[:, ind_discontig0]
self.assertEqual(res_ind, ref, atol=0, rtol=0)
res_ind_neg = src[ind0 - src.shape[dim]] if dim == 0 else src[:, ind0 - src.shape[1]]
self.assertEqual(res_ind_neg, ref, atol=0, rtol=0)
res = torch.gather(discontig, dim=dim, index=ind)
self.assertEqual(res, ref, atol=0, rtol=0)
res_ind = discontig[ind0] if dim == 0 else discontig[:, ind0]
self.assertEqual(res_ind, ref, atol=0, rtol=0)
res = torch.gather(misaligned, dim=dim, index=ind)
self.assertEqual(res, ref, atol=0, rtol=0)
res_ind = misaligned[ind0] if dim == 0 else misaligned[:, ind0]
self.assertEqual(res_ind, ref, atol=0, rtol=0)
res_ind = misaligned1[ind0] if dim == 0 else misaligned[:, ind0]
self.assertEqual(res_ind, ref, atol=0, rtol=0)
res_gather = torch.gather(misaligned1, dim=dim, index=ind)
self.assertEqual(res_gather, ref, atol=0, rtol=0)
# test gather along 1st dim that can accidentally trigger fast path
# because due to index dimension in the gather dim being 1
# an unexpected squashing in tensorIterator happens
src = make_tensor((16, 2, 16), device=device, dtype=dtype)
ind = torch.randint(2, (16, 1), device=device).view(16, 1, 1).expand(16, 1, 16)
res = torch.gather(src, dim=1, index=ind)
if res.device.type == "cuda":
ref_cpu = torch.gather(src.cpu(), dim=1, index=ind.cpu())
self.assertEqual(res.cpu(), ref_cpu, atol=0, rtol=0)
@dtypes(torch.bool)
def test_gather_bool(self, device, dtype):
src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)

View File

@ -13,34 +13,18 @@
#include <cuda/atomic>
#endif
#endif
#include <ATen/native/cuda/MemoryAccess.cuh>
namespace c10d::symmetric_memory {
template <typename T>
__inline__ size_t get_alignment(T ptr_or_size) {
auto val = reinterpret_cast<uintptr_t>(ptr_or_size);
if (val % 16 == 0) {
return 16;
} else if (val % 8 == 0) {
return 8;
} else if (val % 4 == 0) {
return 4;
} else if (val % 2 == 0) {
return 2;
} else {
return 1;
}
}
template <int Size>
using Vec = at::native::memory::Vec<Size>;
template <>
__inline__ size_t get_alignment<size_t>(size_t size) {
return get_alignment(reinterpret_cast<void*>(size));
}
template <class... T>
inline constexpr bool dependent_false =
at::native::memory::dependent_false<T...>;
template <bool Value, class... Args>
inline constexpr bool dependent_bool_value = Value;
template <class... Args>
inline constexpr bool dependent_false = dependent_bool_value<false, Args...>;
using at::native::memory::get_alignment;
template <std::memory_order Sem>
__device__ __forceinline__ uint32_t
@ -170,33 +154,6 @@ __device__ __forceinline__ void sync_remote_blocks<std::memory_order_acq_rel>(
}
}
template <int Size>
union Vec;
template <>
union Vec<4> {
uint16_t u16[2];
uint32_t u32, as_scalar;
float f32;
};
template <>
union Vec<8> {
uint16_t u16[4];
uint32_t u32[2];
uint64_t u64, as_scalar;
float f32[2];
};
template <>
union alignas(16) Vec<16> {
uint16_t u16[8];
uint32_t u32[4];
uint64_t u64[2];
uint4 u128, as_scalar;
float f32[4];
};
template <typename T>
struct MultimemLdReduce {
template <int Alignment>
@ -287,58 +244,6 @@ __device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#endif
}
template <int Alignment, typename T>
__device__ __inline__ Vec<Alignment> ld_vec(const T* addr) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
Vec<Alignment> vec;
if constexpr (Alignment == 16) {
asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];"
: "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3])
: "l"(addr)
: "memory");
} else if constexpr (Alignment == 8) {
asm("ld.global.v2.u32 {%0,%1}, [%2];"
: "=r"(vec.u32[0]), "=r"(vec.u32[1])
: "l"(addr)
: "memory");
} else if constexpr (Alignment == 4) {
asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory");
} else {
static_assert(dependent_false<T>);
}
return vec;
#endif
}
template <int Alignment, typename T>
__device__ __inline__ void st_vec(T* addr, const Vec<Alignment>& vec) {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
#else
if constexpr (Alignment == 16) {
asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};"
:
: "l"(addr),
"r"(vec.u32[0]),
"r"(vec.u32[1]),
"r"(vec.u32[2]),
"r"(vec.u32[3])
: "memory");
} else if constexpr (Alignment == 8) {
asm("st.global.v2.u32 [%0], {%1,%2};"
:
: "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1])
: "memory");
} else if constexpr (Alignment == 4) {
asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory");
} else {
static_assert(dependent_false<T>);
}
#endif
}
#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif
@ -405,7 +310,8 @@ load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) {
#pragma unroll k_world_size
for (size_t step = 0; step < k_world_size; ++step) {
size_t remote_rank = (rank + step) % k_world_size;
vecs[remote_rank] = ld_vec<alignment>(ptrs[remote_rank] + offset);
vecs[remote_rank] =
at::native::memory::ld_vec<alignment>(ptrs[remote_rank] + offset);
}
auto acc = vecs[0];
#pragma unroll k_world_size - 1
@ -422,7 +328,7 @@ __device__ inline std::enable_if_t<(k_world_size <= 0), Vec<alignment>>
load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) {
Vec<alignment> acc{};
for (size_t step = 0; step < world_size; ++step) {
auto vec = ld_vec<alignment>(ptrs[step] + offset);
auto vec = at::native::memory::ld_vec<alignment>(ptrs[step] + offset);
acc = add_vec<alignment, T>(acc, vec);
}
return acc;

View File

@ -73,7 +73,7 @@ size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
const size_t min_alignment = std::max(4l, input.element_size());
// Only check the offset since the multicast address is always at least
// 128-bit aligned
const size_t ptr_alignment = get_alignment(
const size_t ptr_alignment = at::native::memory::get_alignment(
static_cast<size_t>(input.storage_offset() * input.element_size()));
TORCH_CHECK(
ptr_alignment >= min_alignment,
@ -85,7 +85,7 @@ size_t get_and_verify_alignment(const at::Tensor& input, const char* op_name) {
"-byte aligned.");
const size_t size_alignment =
get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
at::native::memory::get_alignment(static_cast<size_t>(input.numel() * input.element_size()));
TORCH_CHECK(
size_alignment >= min_alignment,
op_name,
@ -226,7 +226,7 @@ static __global__ void multimem_one_shot_all_reduce_kernel(
auto stride = blockDim.x * gridDim.x * numel_per_thread;
for (size_t i = offset; i < numel; i += stride) {
auto vec = multimem_ld_reduce_add<alignment>(input_mc_ptr + i);
st_vec<alignment>(output_ptr + i, vec);
at::native::memory::st_vec<alignment>(output_ptr + i, vec);
}
__syncthreads();
@ -319,7 +319,7 @@ static __global__ void multimem_all_gather_kernel(
auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * alignment;
auto stride = blockDim.x * gridDim.x * alignment;
for (size_t i = offset; i < bytes_per_rank; i += stride) {
auto vec = ld_vec<alignment>(input_ptr + i);
auto vec = at::native::memory::ld_vec<alignment>(input_ptr + i);
multimem_st<alignment>(output_mc_ptr + start + i, vec);
}
@ -419,8 +419,8 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
auto stride = blockDim.x * gridDim.x * numel_per_thread;
if (input_ptr) {
for (size_t i = offset; i < numel; i += stride) {
Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
Vec<alignment> vec_st = at::native::memory::ld_vec<alignment>(input_ptr + i);
at::native::memory::st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
}
}
// TODO make it sync with one block for no-copy case
@ -430,7 +430,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
for (size_t i = offset; i < numel; i += stride) {
auto vec = load_and_reduce<T, alignment, k_world_size>(
input_ptrs, rank, world_size, input_offset + i);
st_vec<alignment>(output_ptr + i, vec);
at::native::memory::st_vec<alignment>(output_ptr + i, vec);
}
__syncthreads();
@ -607,9 +607,9 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
input_ptrs, rank, world_size, input_offset + start + idx);
// store to local buffer or to output
if constexpr (reduce_scatter) {
st_vec<alignment>(output_ptr + i, vec);
at::native::memory::st_vec<alignment>(output_ptr + i, vec);
} else {
st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
at::native::memory::st_vec<alignment>(input_ptrs[rank] + input_offset + start + i, vec);
}
}
@ -628,7 +628,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
if (remote_start + i >= numel) {
continue;
}
tmp[step] = ld_vec<alignment>(
tmp[step] = at::native::memory::ld_vec<alignment>(
input_ptrs[remote_rank] + input_offset + remote_start + i);
}
#pragma unroll k_world_size
@ -638,7 +638,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
if (remote_start + i >= numel) {
continue;
}
st_vec<alignment>(output_ptr + remote_start + i, tmp[step]);
at::native::memory::st_vec<alignment>(output_ptr + remote_start + i, tmp[step]);
}
}
// need to make sure all blocks exit simultaneously so that the data
@ -676,7 +676,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__
input_ptrs, rank, world_size, input_offset + start + i);
for (size_t step = 0; step < world_size; ++step) {
size_t remote_rank = (rank + step) % world_size;
st_vec<alignment>(
at::native::memory::st_vec<alignment>(
input_ptrs[remote_rank] + input_offset + start + i, vec);
}
}