mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "use more elements per thread for narrow dtypes (#139449)"
This reverts commit d3fc13a9dd186ceb8d1b56b0968a41686ea645cd. Reverted https://github.com/pytorch/pytorch/pull/139449 on behalf of https://github.com/ngimel due to breaks tests ([comment](https://github.com/pytorch/pytorch/pull/139449#issuecomment-2477012582))
@@ -52,49 +52,13 @@
 
 namespace at::native {
 
-template <typename args_t, size_t... Is>
-constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
-  if constexpr (sizeof...(Is) == 0) {
-    return 0;
-  } else {
-    return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
-  }
-}
-
-template <int io_sizes>
-constexpr auto elems_per_thread(){
-  if constexpr (io_sizes == 1) {
-    return 16;
-  } else if constexpr (io_sizes < 4) {
-    return 8;
-  } else {
-    return 4;
-  }
-}
-
-template <int io_sizes>
-constexpr auto io_block_work_size() {
-  return num_threads() * elems_per_thread<io_sizes>();
-}
-
-template <typename func_t>
-constexpr auto calc_io_size(){
-  using traits = function_traits<func_t>;
-  using args_t = typename traits::ArgsTuple;
-  constexpr auto input_size = at::native::sum_of_sizes(args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
-  constexpr auto output_size = sizeof(typename traits::result_type);
-  return input_size + output_size;
-}
-
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads())
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
   using traits = function_traits<func_t>;
-  constexpr auto io_size = calc_io_size<func_t>();
-  int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
+  int remaining = N - block_work_size() * blockIdx.x;
 
-  if (remaining < io_block_work_size<io_size>()) { // if this block handles the reminder,
+  if (remaining < block_work_size()) { // if this block handles the reminder,
     // just do a naive unrolled loop
     auto input_calc = TrivialOffsetCalculator<traits::arity>();
     auto output_calc = TrivialOffsetCalculator<1>();
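
The deleted heuristic above is self-contained enough to exercise off-device. The following is a minimal standalone sketch, not PyTorch code: `sum_of_sizes` and `elems_per_thread` are copied from the removed lines, while `io_size_for` is a hypothetical stand-in for `calc_io_size<func_t>` (the real helper derives the argument tuple through `function_traits`).

```cpp
#include <cstddef>
#include <tuple>
#include <utility>

// Copied from the removed hunk: total sizeof of a functor's argument tuple.
template <typename args_t, size_t... Is>
constexpr auto sum_of_sizes(args_t, std::index_sequence<Is...>) {
  if constexpr (sizeof...(Is) == 0) {
    return 0;
  } else {
    return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
  }
}

// Copied from the removed hunk: smaller combined footprint -> more elements.
template <int io_sizes>
constexpr auto elems_per_thread() {
  if constexpr (io_sizes == 1) {
    return 16;
  } else if constexpr (io_sizes < 4) {
    return 8;
  } else {
    return 4;
  }
}

// Hypothetical stand-in for calc_io_size<func_t>; the real helper pulls the
// argument tuple out of function_traits instead of taking it explicitly.
template <typename result_t, typename... arg_ts>
constexpr int io_size_for() {
  return sum_of_sizes(std::tuple<arg_ts...>{},
                      std::make_index_sequence<sizeof...(arg_ts)>{}) +
         sizeof(result_t);
}

// bool() nullary (e.g. a fill): 1 byte of traffic -> 16 elements per thread.
static_assert(elems_per_thread<io_size_for<bool>()>() == 16);
// bool(bool): 1 + 1 = 2 bytes per element -> 8 elements per thread.
static_assert(elems_per_thread<io_size_for<bool, bool>()>() == 8);
// int8_t(int8_t, int8_t): 3 bytes, still under the cutoff -> 8.
static_assert(elems_per_thread<io_size_for<signed char, signed char, signed char>()>() == 8);
// float(float, float): 12 bytes -> the default 4 elements per thread.
static_assert(elems_per_thread<io_size_for<float, float, float>()>() == 4);
```

So only functors whose total per-element traffic stayed under 4 bytes (bool/int8-style ops, the "narrow dtypes" of the reverted title) received the larger 8- or 16-element workload; 4-byte-and-wider ops kept the usual 4.
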
@@ -105,21 +69,19 @@ __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
         decltype(input_calc),
         decltype(output_calc),
         memory::LoadWithoutCast,
-        memory::StoreWithoutCast,
-        elems_per_thread<io_size>()>(
+        memory::StoreWithoutCast>(
         data, remaining, input_calc, output_calc, loader, storer);
     elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
     elementwise_kernel_helper(
-        f, memory::policies::vectorized<vec_size, array_t, elems_per_thread<io_size>()>(data));
+        f, memory::policies::vectorized<vec_size, array_t>(data));
   }
 }
 
 template <
     typename func_t,
     typename array_t,
-    int elems_per_thread,
     typename inp_calc_t,
     typename out_calc_t,
     typename loader_t,
@@ -135,7 +97,7 @@ __global__ void unrolled_elementwise_kernel(
     storer_t s) {
   int remaining = N - block_work_size() * blockIdx.x;
   auto policy = memory::policies::
-      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t, elems_per_thread>(
+      unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
           data, remaining, ic, oc, l, s);
   elementwise_kernel_helper(f, policy);
 }
@@ -148,8 +110,7 @@ static inline void launch_vectorized_kernel(
     array_t data) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   using traits = function_traits<func_t>;
-  constexpr auto io_size = calc_io_size<func_t>();
-  int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
+  int64_t grid = (N + block_work_size() - 1) / block_work_size();
   auto stream = at::cuda::getCurrentCUDAStream();
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 
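
Both sides of this hunk size the grid the same way, a ceiling division of N by the per-block element count; the revert only changes that count from `io_block_work_size<io_size>()` back to `block_work_size()`. A small sketch with assumed constants (the real values are supplied by `num_threads()` and `thread_work_size()` in PyTorch's CUDA loop headers):

```cpp
#include <cstdint>

// Assumed constants for illustration only.
constexpr int64_t num_threads = 128;
constexpr int64_t thread_work_size = 4;
constexpr int64_t block_work_size = num_threads * thread_work_size;  // 512

// Ceiling division: enough blocks that every element is covered; the last,
// possibly partial, block takes the unrolled (non-vectorized) path.
constexpr int64_t grid_for(int64_t N) {
  return (N + block_work_size - 1) / block_work_size;
}

static_assert(grid_for(1) == 1);
static_assert(grid_for(512) == 1);  // exactly one full block
static_assert(grid_for(513) == 2);  // one full block plus one remainder block
```
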
@@ -169,7 +130,7 @@ static inline void launch_vectorized_kernel(
       auto output_calc = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithoutCast();
       auto storer = memory::StoreWithoutCast();
-      unrolled_elementwise_kernel<func_t, array_t, elems_per_thread<io_size>()>
+      unrolled_elementwise_kernel<func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(
               N, f, data, input_calc, output_calc, loader, storer);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -198,7 +159,7 @@ static inline void launch_unrolled_kernel(
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
   int64_t grid = (N + block_work_size() - 1) / block_work_size();
   auto stream = at::cuda::getCurrentCUDAStream();
-  unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
+  unrolled_elementwise_kernel<func_t, array_t>
       <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -46,19 +46,18 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
   using args_t = typename traits::ArgsTuple;
-  constexpr int elems_per_thread = policy_t::tws;
 
   int idx = blockIdx.x;
 
-  return_t results[elems_per_thread];
-  args_t args[elems_per_thread];
+  return_t results[thread_work_size()];
+  args_t args[thread_work_size()];
 
   // load
   policy.load(args, idx);
 
   // compute
   #pragma unroll
-  for (int i = 0; i < elems_per_thread; i++) {
+  for (int i = 0; i < thread_work_size(); i++) {
     if (policy.check_inbounds(i)) {
       results[i] = c10::guts::apply(f, args[i]);
     }
@@ -57,11 +57,11 @@ struct static_unroll<func, end, end> {
 template<int arg_index>
 struct vectorized_load_helper {
   template <typename args_t, typename policy_t>
-  static __device__ void apply(policy_t &self, args_t *args, int idx, int block_work_size) {
+  static __device__ void apply(policy_t &self, args_t *args, int idx) {
     using arg_t = std::tuple_element_t<arg_index, args_t>;
     // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
     // need a +1 offset to get the input
-    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size * idx;
+    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
     auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
     self.load_single_arg(args_accessor, ptr);
   }
@@ -181,7 +181,9 @@ __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint
 
 namespace policies {
 
-template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int elems_per_thread, int num_outputs=1>
+// Assumption:
+// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
+template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
 struct unroll {
 
   data_t data;
@@ -190,7 +192,6 @@ struct unroll {
   out_calc_t output_offset_calculator;
   loader_t loader;
   storer_t storer;
-  static constexpr int tws = elems_per_thread;
 
   __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
     data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
@@ -204,11 +205,11 @@ struct unroll {
     constexpr int arity = std::tuple_size_v<args_t>;
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < elems_per_thread; i++) {
+    for (int i = 0; i < thread_work_size(); i++) {
       if (thread_idx >= remaining) {
         return;
       }
-      int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
+      int linear_idx = thread_idx + block_work_size() * idx;
       auto offset = input_offset_calculator.get(linear_idx);
       detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
       thread_idx += num_threads();
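
Note that `elems_per_thread * num_threads()` on the removed line and `block_work_size()` on the restored line denote the same quantity once `elems_per_thread` equals `thread_work_size()`; each thread walks its block's slice with stride `num_threads()`, so neighboring threads touch neighboring elements and accesses stay coalesced. A host-side simulation of that index pattern with small assumed constants, checking that the bounds test visits every element exactly once:

```cpp
#include <cassert>
#include <vector>

int main() {
  // Small assumed constants for illustration; PyTorch's real values come
  // from num_threads() and thread_work_size().
  constexpr int num_threads = 4;
  constexpr int thread_work_size = 2;
  constexpr int block_work_size = num_threads * thread_work_size;
  constexpr int N = 13;  // deliberately not a multiple of block_work_size

  std::vector<int> hits(N, 0);
  const int num_blocks = (N + block_work_size - 1) / block_work_size;
  for (int idx = 0; idx < num_blocks; ++idx) {       // plays blockIdx.x
    const int remaining = N - block_work_size * idx;
    for (int t = 0; t < num_threads; ++t) {          // plays threadIdx.x
      int thread_idx = t;
      for (int i = 0; i < thread_work_size; ++i) {
        if (thread_idx >= remaining) break;          // the hunk's bounds check
        ++hits[thread_idx + block_work_size * idx];  // the hunk's linear_idx
        thread_idx += num_threads;                   // stride by blockDim.x
      }
    }
  }
  for (int h : hits) assert(h == 1);  // every element visited exactly once
  return 0;
}
```
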
@@ -219,11 +220,11 @@ struct unroll {
   __device__ inline void store(scalar_t *from, int idx) {
     int thread_idx = threadIdx.x;
     #pragma unroll
-    for (int i = 0; i < elems_per_thread; i++) {
+    for (int i = 0; i < thread_work_size(); i++) {
       if (thread_idx >= remaining) {
         return;
       }
-      int linear_idx = thread_idx + elems_per_thread * num_threads() * idx;
+      int linear_idx = thread_idx + block_work_size() * idx;
       int offset = output_offset_calculator.get(linear_idx)[0];
       storer.store(from[i], data[0], offset);
       thread_idx += num_threads();
@@ -236,12 +237,11 @@ struct unroll {
 // Note:
 // Functions in vectorized policy does not do boundary check. It assumes the whole block
 // has its job to do. So the reminders should be handled by the caller manually.
-template <int vec_size, typename data_t, int elems_per_thread> // vec_size: number of scalars, can be 1, 2, or 4.
+template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
 struct vectorized {
 
-  static_assert(elems_per_thread % vec_size == 0, "The workload per thread must be a multiple of vec_size");
-  static constexpr int loop_size = elems_per_thread / vec_size;
-  static constexpr int tws = elems_per_thread;
+  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
+  static constexpr int loop_size = thread_work_size() / vec_size;
 
   data_t data;
 
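
The restored `static_assert` pins the per-thread workload to a multiple of the vector width, and `loop_size` is then the number of wide memory operations each thread issues. A sketch with an assumed `thread_work_size` of 4 (`vectorized_demo` is illustrative, not the real struct):

```cpp
// thread_work_size is assumed to be 4 here; the real value comes from
// PyTorch's thread_work_size() constant expression.
constexpr int thread_work_size = 4;

template <int vec_size>  // vec_size: number of scalars, can be 1, 2, or 4
struct vectorized_demo {
  static_assert(thread_work_size % vec_size == 0,
                "The workload per thread must be a multiple of vec_size");
  static constexpr int loop_size = thread_work_size / vec_size;
};

static_assert(vectorized_demo<4>::loop_size == 1);  // one wide op per thread
static_assert(vectorized_demo<2>::loop_size == 2);
static_assert(vectorized_demo<1>::loop_size == 4);  // degenerates to scalar
```
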
@@ -268,13 +268,13 @@ struct vectorized {
   template<typename args_t>
   __device__ inline void load(args_t *args, int idx) {
     constexpr int arity = std::tuple_size_v<args_t>;
-    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx, elems_per_thread * num_threads());
+    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
   }
 
   template<typename scalar_t>
   __device__ inline void store(scalar_t *from, int idx) {
     using vec_t = aligned_vector<scalar_t, vec_size>;
-    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + elems_per_thread * num_threads() * idx;
+    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
     vec_t *to_ = reinterpret_cast<vec_t *>(to);
     int thread_idx = threadIdx.x;
     #pragma unroll
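
The restored store offsets each block by `block_work_size()` elements and then writes through `aligned_vector`, whose over-alignment is what lets the compiler emit one wide transaction per `vec_size` scalars. A minimal sketch of that idea (an illustrative reimplementation mirroring the shape of PyTorch's `aligned_vector`, not the header itself):

```cuda
// Over-aligned wrapper: the alignas allows a single wide (here 16-byte)
// transaction, assuming the underlying pointers are suitably aligned.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// Each thread moves vec_size scalars with one wide load and one wide store.
__global__ void copy_vec4(const float *src, float *dst) {
  using vec_t = aligned_vector<float, 4>;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  reinterpret_cast<vec_t *>(dst)[i] = reinterpret_cast<const vec_t *>(src)[i];
}
```

Launched with enough threads to cover N / 4 vectors, any tail that does not fill a whole vector would need the unrolled path, which is exactly the division of labor between the two kernels above.
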
@@ -299,7 +299,6 @@ struct multi_outputs_unroll {
   out_calc_t output_offset_calculator;
   LoadWithoutCast loader;
   StoreWithoutCast storer;
-  static constexpr int tws = thread_work_size();
 
   __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
     data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
@@ -82,7 +82,7 @@ __global__ void vectorized_copy(scalar_t *dst, scalar_t *src) {
   data[0] = reinterpret_cast<char *>(dst);
   data[1] = reinterpret_cast<char *>(src);
   int idx = blockIdx.x;
-  using vectorized = policies::vectorized<vec_size, array_t, thread_work_size()>;
+  using vectorized = policies::vectorized<vec_size, array_t>;
   auto policy = vectorized(data);
   scalar_t buf[thread_work_size()];
 #if !defined(USE_ROCM)
@@ -1045,6 +1045,7 @@ class TestReductions(TestCase):
         a[:, (shape[1] - 1) // 2:] = True
         values, indices = a.mode(-1)
         self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
+        print(indices)
         indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
         self.assertEqual(values, indexed)