mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ATen][CUDA] Optimize 128 bit vectorization (#148320)
Fixes #147376. As per request: https://github.com/pytorch/pytorch/pull/145746#pullrequestreview-2642118301 This PR omits sm80 or older of using vec8 kernels due to long compilation and large binary size. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148320 Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
3baa85cfad
commit
72337bdcf2
@ -158,6 +158,69 @@ constexpr auto calc_io_size(){
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// To save on binary size of libtorch_cuda.so, we split the vectorized_elementwise_kernel
|
||||
// into two: one for vec_size=8 and one for vec_size=[2, 4], since vec8 is going to be
|
||||
// used on sm_90 and sm_100 exclusively.
|
||||
template <int vec_size, typename func_t, typename array_t>
|
||||
C10_LAUNCH_BOUNDS_1(num_threads())
|
||||
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
|
||||
if constexpr (vec_size == 8) {
|
||||
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
|
||||
using traits = function_traits<func_t>;
|
||||
constexpr auto io_size = calc_io_size<func_t>();
|
||||
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
|
||||
|
||||
if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
|
||||
// just do a naive unrolled loop
|
||||
auto input_calc = TrivialOffsetCalculator<traits::arity>();
|
||||
auto output_calc = TrivialOffsetCalculator<1>();
|
||||
auto loader = memory::LoadWithoutCast();
|
||||
auto storer = memory::StoreWithoutCast();
|
||||
auto policy = memory::policies::unroll<
|
||||
array_t,
|
||||
decltype(input_calc),
|
||||
decltype(output_calc),
|
||||
memory::LoadWithoutCast,
|
||||
memory::StoreWithoutCast,
|
||||
elems_per_thread<io_size>()>(
|
||||
data, remaining, input_calc, output_calc, loader, storer);
|
||||
elementwise_kernel_helper(f, policy);
|
||||
} else { // if this block has a full `block_work_size` data to handle, use
|
||||
// vectorized memory access
|
||||
elementwise_kernel_helper(
|
||||
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
|
||||
}
|
||||
#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
|
||||
} else {
|
||||
using traits = function_traits<func_t>;
|
||||
constexpr auto io_size = calc_io_size<func_t>();
|
||||
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
|
||||
|
||||
if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
|
||||
// just do a naive unrolled loop
|
||||
auto input_calc = TrivialOffsetCalculator<traits::arity>();
|
||||
auto output_calc = TrivialOffsetCalculator<1>();
|
||||
auto loader = memory::LoadWithoutCast();
|
||||
auto storer = memory::StoreWithoutCast();
|
||||
auto policy = memory::policies::unroll<
|
||||
array_t,
|
||||
decltype(input_calc),
|
||||
decltype(output_calc),
|
||||
memory::LoadWithoutCast,
|
||||
memory::StoreWithoutCast,
|
||||
elems_per_thread<io_size>()>(
|
||||
data, remaining, input_calc, output_calc, loader, storer);
|
||||
elementwise_kernel_helper(f, policy);
|
||||
} else { // if this block has a full `block_work_size` data to handle, use
|
||||
// vectorized memory access
|
||||
elementwise_kernel_helper(
|
||||
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else // USE_ROCM
|
||||
template <int vec_size, typename func_t, typename array_t>
|
||||
C10_LAUNCH_BOUNDS_1(num_threads())
|
||||
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
|
||||
@ -182,15 +245,12 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
|
||||
elementwise_kernel_helper(f, policy);
|
||||
} else { // if this block has a full `block_work_size` data to handle, use
|
||||
// vectorized memory access
|
||||
#ifdef USE_ROCM
|
||||
constexpr auto optimal_vec_size = calc_optimal_vec_size<vec_size, io_size>();
|
||||
#else
|
||||
constexpr auto optimal_vec_size = vec_size;
|
||||
#endif
|
||||
elementwise_kernel_helper(
|
||||
f, memory::policies::vectorized<optimal_vec_size, array_t, elems_per_thread<io_size>()>(data));
|
||||
}
|
||||
}
|
||||
#endif // USE_ROCM
|
||||
|
||||
template <
|
||||
typename func_t,
|
||||
@ -237,6 +297,11 @@ static inline void launch_vectorized_kernel(
|
||||
// Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
|
||||
// that causes some numerical mismatches with uint8 on sm80 and sm90.
|
||||
// TODO: Revisit this after CUDA 12.8 update.
|
||||
cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
|
||||
const int computeCapability = p->major * 10 + p->minor;
|
||||
if (computeCapability != 90 && computeCapability != 100) {
|
||||
vec_size = std::min<uint16_t>(vec_size, 4);
|
||||
}
|
||||
if constexpr (sizeof(cpp_type) < 2) {
|
||||
vec_size = std::min<uint16_t>(vec_size, 4);
|
||||
}
|
||||
|
Reference in New Issue
Block a user