use more elements per thread for narrow dtypes (#139449)

Fix perf issue for narrow type by accessing more elements per thread

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139449
Approved by: https://github.com/Chillee, https://github.com/eqy
This commit is contained in:
Natalia Gimelshein
2024-11-04 16:43:33 +00:00
committed by PyTorch MergeBot
parent 3ca794783f
commit d3fc13a9dd
5 changed files with 67 additions and 27 deletions

View File

@ -52,13 +52,49 @@
namespace at::native {
template <typename args_t, size_t... Is>
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
if constexpr (sizeof...(Is) == 0) {
return 0;
} else {
return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
}
}
template <int io_sizes>
constexpr auto elems_per_thread(){
if constexpr (io_sizes == 1) {
return 16;
} else if constexpr (io_sizes < 4) {
return 8;
} else {
return 4;
}
}
template <int io_sizes>
constexpr auto io_block_work_size() {
return num_threads() * elems_per_thread<io_sizes>();
}
template <typename func_t>
constexpr auto calc_io_size(){
using traits = function_traits<func_t>;
using args_t = typename traits::ArgsTuple;
constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
constexpr auto output_size = sizeof(typename traits::result_type);
return input_size + output_size;
}
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
using traits = function_traits<func_t>;
int remaining = N - block_work_size() * blockIdx.x;
constexpr auto io_size = calc_io_size<func_t>();
int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
if (remaining < block_work_size()) { // if this block handles the reminder,
if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
// just do a naive unrolled loop
auto input_calc = TrivialOffsetCalculator<traits::arity>();
auto output_calc = TrivialOffsetCalculator<1>();
@ -69,19 +105,21 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
decltype(input_calc),
decltype(output_calc),
memory::LoadWithoutCast,
memory::StoreWithoutCast>(
memory::StoreWithoutCast,
elems_per_thread<io_size>()>(
data, remaining, input_calc, output_calc, loader, storer);
elementwise_kernel_helper(f, policy);
} else { // if this block has a full `block_work_size` data to handle, use
// vectorized memory access
elementwise_kernel_helper(
f, memory::policies::vectorized<vec_size, array_t>(data));
f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
}
}
template <
typename func_t,
typename array_t,
int elems_per_thread,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
@ -97,7 +135,7 @@ __global__ void unrolled_elementwise_kernel(
storer_t s) {
int remaining = N - block_work_size() * blockIdx.x;
auto policy = memory::policies::
unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t, elems_per_thread>(
data, remaining, ic, oc, l, s);
elementwise_kernel_helper(f, policy);
}
@ -110,7 +148,8 @@ static inline void launch_vectorized_kernel(
array_t data) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
using traits = function_traits<func_t>;
int64_t grid = (N + block_work_size() - 1) / block_work_size();
constexpr auto io_size = calc_io_size<func_t>();
int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
auto stream = at::cuda::getCurrentCUDAStream();
int vec_size = memory::can_vectorize_up_to<func_t>(data);
@ -130,7 +169,7 @@ static inline void launch_vectorized_kernel(
auto output_calc = TrivialOffsetCalculator<1>();
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
unrolled_elementwise_kernel<func_t, array_t>
unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
<<<grid, num_threads(), 0, stream>>>(
N, f, data, input_calc, output_calc, loader, storer);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -159,7 +198,7 @@ static inline void launch_unrolled_kernel(
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + block_work_size() - 1) / block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_elementwise_kernel<func_t, array_t>
unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
<<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

View File

@ -46,18 +46,19 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;
constexpr int elems_per_thread = policy_t::tws;
int idx = blockIdx.x;
return_t results[thread_work_size()];
args_t args[thread_work_size()];
return_t results[elems_per_thread];
args_t args[elems_per_thread];
// load
policy.load(args, idx);
// compute
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}

View File

@ -57,11 +57,11 @@ struct static_unroll<func, end, end> {
template<int arg_index>
struct vectorized_load_helper {
template <typename args_t, typename policy_t>
static __device__ void apply(policy_t &self, args_t *args, int idx) {
static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) {
using arg_t = std::tuple_element_t<arg_index, args_t>;
// `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
// need a +1 offset to get the input
auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size * idx;
auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
self.load_single_arg(args_accessor, ptr);
}
@ -181,9 +181,7 @@ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint
namespace policies {
// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int elems_per_thread, int num_outputs=1>
struct unroll {
data_t data;
@ -192,6 +190,7 @@ struct unroll {
out_calc_t output_offset_calculator;
loader_t loader;
storer_t storer;
static constexpr int tws = elems_per_thread;
__device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
@ -205,11 +204,11 @@ struct unroll {
constexpr int arity = std::tuple_size_v<args_t>;
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (thread_idx >= remaining) {
return;
}
int linear_idx = thread_idx + block_work_size() * idx;
int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
auto offset = input_offset_calculator.get(linear_idx);
detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
thread_idx += num_threads();
@ -220,11 +219,11 @@ struct unroll {
__device__ inline void store(scalar_t *from, int idx) {
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
for (int i = 0; i < elems_per_thread; i++) {
if (thread_idx >= remaining) {
return;
}
int linear_idx = thread_idx + block_work_size() * idx;
int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
int offset = output_offset_calculator.get(linear_idx)[0];
storer.store(from[i], data[0], offset);
thread_idx += num_threads();
@ -237,11 +236,12 @@ struct unroll {
// Note:
// Functions in vectorized policy does not do boundary check. It assumes the whole block
// has its job to do. So the reminders should be handled by the caller manually.
template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
template <int vec_size, typename data_t, int elems_per_thread> // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {
static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
static constexpr int loop_size = thread_work_size() / vec_size;
static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size");
static constexpr int loop_size = elems_per_thread / vec_size;
static constexpr int tws = elems_per_thread;
data_t data;
@ -268,13 +268,13 @@ struct vectorized {
template<typename args_t>
__device__ inline void load(args_t *args, int idx) {
constexpr int arity = std::tuple_size_v<args_t>;
detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx, elems_per_thread * num_threads());
}
template<typename scalar_t>
__device__ inline void store(scalar_t *from, int idx) {
using vec_t = aligned_vector<scalar_t, vec_size>;
scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + elems_per_thread * num_threads() * idx;
vec_t *to_ = reinterpret_cast<vec_t *>(to);
int thread_idx = threadIdx.x;
#pragma unroll
@ -299,6 +299,7 @@ struct multi_outputs_unroll {
out_calc_t output_offset_calculator;
LoadWithoutCast loader;
StoreWithoutCast storer;
static constexpr int tws = thread_work_size();
__device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

View File

@ -82,7 +82,7 @@ __global__ void vectorized_copy(scalar_t *dst, scalar_t *src) {
data[0] = reinterpret_cast<char *>(dst);
data[1] = reinterpret_cast<char *>(src);
int idx = blockIdx.x;
using vectorized = policies::vectorized<vec_size, array_t>;
using vectorized = policies::vectorized<vec_size, array_t, thread_work_size()>;
auto policy = vectorized(data);
scalar_t buf[thread_work_size()];
#if !defined(USE_ROCM)

View File

@ -1045,7 +1045,6 @@ class TestReductions(TestCase):
a[:, (shape[1] - 1) // 2:] = True
values, indices = a.mode(-1)
self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
print(indices)
indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
self.assertEqual(values, indexed)