[ATen][CUDA] Optimize 128 bit vectorization (#148320)

Fixes #147376.
As per request: https://github.com/pytorch/pytorch/pull/145746#pullrequestreview-2642118301
This PR keeps sm80 and older architectures from using the vec8 kernels, due to long compilation times and large binary size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
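
The change gates vec8 at two levels: the device code compiles the vec8 body only for sm_90 and sm_100 (via a __CUDA_ARCH__ guard), and the host side caps the requested vector size on every other compute capability. A minimal sketch of that pattern (simplified, not the exact PyTorch code; pick_vec_size is a hypothetical helper):

    // Device side: the vec8 body exists only in the sm_90 / sm_100 compilations.
    template <int vec_size, typename func_t, typename array_t>
    __global__ void kernel(int N, func_t f, array_t data) {
      if constexpr (vec_size == 8) {
    #if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
        // ... 128-bit (vec8) fast path ...
    #endif
      } else {
        // ... vec2 / vec4 path, compiled for all architectures ...
      }
    }

    // Host side: never request vec8 unless the device is sm_90 or sm_100.
    int pick_vec_size(int desired, const cudaDeviceProp& prop) {  // hypothetical helper
      const int cc = prop.major * 10 + prop.minor;
      return (cc == 90 || cc == 100) ? desired : std::min(desired, 4);
    }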
Aidyn-A
2025-05-02 17:35:44 +00:00
committed by PyTorch MergeBot
parent 3baa85cfad
commit 72337bdcf2


@ -158,6 +158,69 @@ constexpr auto calc_io_size(){
#endif
}
#ifndef USE_ROCM
// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
// used on sm_90 and sm_100 exclusively.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
auto policy = memory::policies::unroll<
array_t,
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
} else {
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
if (remaining < io_block_work_size<io_size>()) { // if this block handles the remainder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
auto policy = memory::policies::unroll<
array_t,
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
}
#else // USE_ROCM
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
@ -182,15 +245,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
#ifdef USE_ROCM
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
#else
constexpr auto optimal_vec_size = vec_size;
#endif
elementwise_kernel_helper(
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
#endif // USE_ROCM
template <
typename func_t,
@ -237,6 +297,11 @@ static inline void launch_vectorized_kernel(
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
// that causes some numerical mismatches with uint8 on sm80 and sm90.
// TODO: Revisit this after CUDA 12.8 update.
cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
const int computeCapability = p->major * 10 + p->minor;
if (computeCapability != 90 && computeCapability != 100) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
if constexpr (sizeof(cpp_type) < 2) {
vec_size = std::min<uint16_t>(vec_size, 4);
}
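
For context, launch_vectorized_kernel then dispatches the (possibly capped) vec_size to the templated kernel. A rough sketch of that dispatch, assuming the usual switch over vec_size in the surrounding launch code (grid sizing, error checking, and the vec_size == 1 fallback are elided):

    switch (vec_size) {
      case 8:
        vectorized_elementwise_kernel<8, func_t, array_t>
            <<<grid, num_threads(), 0, stream>>>(N, f, data);
        break;
      case 4:
        vectorized_elementwise_kernel<4, func_t, array_t>
            <<<grid, num_threads(), 0, stream>>>(N, f, data);
        break;
      case 2:
        vectorized_elementwise_kernel<2, func_t, array_t>
            <<<grid, num_threads(), 0, stream>>>(N, f, data);
        break;
      default:
        // vec_size == 1 falls back to the unvectorized kernel (not shown here)
        break;
    }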