mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
sort_out_cuda: Use custom kernels to fill index tensors (#66668)
Summary: These stable sorts currently use a combination of `at::arange`, view ops and `tensor.copy_` to fill in the initial values for the indices before calling into `CUB` to do the actual sort. This is somewhat inefficient because it requires 2 to 4 kernel launches, and the copies all use strided kernels instead of the more efficient contiguous kernels. Instead, a fairly straight-forward custom kernel is more efficient in terms of both CUDA and CPU runtime. In a simple benchmark I profiled `a.sort(stable=True, dim=1)` for different shapes and single out the kernel invocations for intitializing the index tensors (i.e. the non-`cub` kernels). Note that when the batch dim is `<128` we call `segmented_sort_pairs_by_full_sort` instead of `segmented_sort_pairs`: | shape | Master (us) | This PR (us) | |--------------|:-----------:|:------------:| | (100, 1000) | 5.000 | 2.300 | | (1000, 100) | 2.070 | 1.090 | | (100, 10000) | 87.34 | 26.47 | | (1000, 1000) | 28.63 | 20.27 | Of course for sufficiently large inputs, the overall runtime is dominated by the actual sort. But I have another motive of wanting to remove operator the calls from the middle of this kernel launch code. This change makes it easier to split the kernel code that needs to be compiled with `nvcc` into it's own file that doesn't include `Tensor.h`, similar to what I'm doing in https://github.com/pytorch/pytorch/issues/66620. Pull Request resolved: https://github.com/pytorch/pytorch/pull/66668 Reviewed By: H-Huang Differential Revision: D31693722 Pulled By: ngimel fbshipit-source-id: 5765926e4dbbc7a20d2940c098ed093b3de2204e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9ba39d2008
commit
1e2b2ee5ff
@ -253,6 +253,25 @@ __global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void fill_index_and_segment_kernel(
|
||||
int2 *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
|
||||
CUDA_KERNEL_LOOP(idx, numel) {
|
||||
auto div_mod = nsort_divider.divmod(idx);
|
||||
auto segment = static_cast<int>(div_mod.div);
|
||||
auto sort = static_cast<int>(div_mod.mod);
|
||||
data[idx] = int2{segment, sort};
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void fill_reverse_indices_kernel(
|
||||
int64_t *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
|
||||
CUDA_KERNEL_LOOP(idx, numel) {
|
||||
data[idx] = nsort_divider.mod(idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename scalar_t>
|
||||
inline void segmented_sort_pairs_by_full_sort(
|
||||
int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices,
|
||||
@ -260,16 +279,21 @@ inline void segmented_sort_pairs_by_full_sort(
|
||||
) {
|
||||
int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
|
||||
|
||||
auto int_options = indices.options().dtype(kInt);
|
||||
auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
|
||||
indices_and_segment.select(-1, 0).copy_( // segment id
|
||||
at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
|
||||
indices_and_segment.select(-1, 1).copy_( // reverse indices
|
||||
at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
|
||||
const auto numel = nsort * nsegments;
|
||||
auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
|
||||
auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2));
|
||||
auto i_s_ptr = static_cast<int2 *>(indices_and_segment.get());
|
||||
|
||||
auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
|
||||
auto indices_and_segment2 = at::empty_like(indices_and_segment);
|
||||
auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
|
||||
using namespace at::cuda::detail;
|
||||
dim3 block = CUDA_NUM_THREADS;
|
||||
dim3 grid = GET_BLOCKS(numel);
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
|
||||
fill_index_and_segment_kernel<<<grid, block, 0, stream>>>(
|
||||
i_s_ptr, numel, nsort_divider);
|
||||
|
||||
auto indices_and_segment2 = cuda_allocator->allocate(nsegments * nsort * sizeof(int2));
|
||||
auto i_s_ptr2 = static_cast<int2 *>(indices_and_segment2.get());
|
||||
|
||||
at::cuda::cub::radix_sort_pairs<scalar_t, int2>(
|
||||
self_ptr, nullptr, i_s_ptr, i_s_ptr2,
|
||||
@ -286,6 +310,28 @@ inline void segmented_sort_pairs_by_full_sort(
|
||||
self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void segmented_sort_pairs(
|
||||
int64_t nsegments, int64_t nsort, int64_t n, bool descending,
|
||||
const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr) {
|
||||
const auto numel = nsort * nsegments;
|
||||
auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
|
||||
auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t));
|
||||
int64_t *reverse_indices_ptr = static_cast<int64_t *>(reverse_indices.get());
|
||||
|
||||
using namespace at::cuda::detail;
|
||||
dim3 block = CUDA_NUM_THREADS;
|
||||
dim3 grid = GET_BLOCKS(numel);
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
|
||||
fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
|
||||
reverse_indices_ptr, numel, nsort_divider);
|
||||
|
||||
at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
|
||||
reverse_indices_ptr, indices_ptr, n, nsegments,
|
||||
offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// We perform a segmented sort in cub with inputs that have
|
||||
@ -414,6 +460,7 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
|
||||
|
||||
int64_t numel_or_intmax = std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
|
||||
int64_t nbatch = (numel_or_intmax / nsort) * nsort;
|
||||
TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
constexpr bool is_rocm = true;
|
||||
@ -434,10 +481,8 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
|
||||
segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
|
||||
indices, self_ptr, values_ptr, indices_ptr);
|
||||
} else {
|
||||
auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
|
||||
at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
|
||||
reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
|
||||
offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
|
||||
segmented_sort_pairs(nsegments, nsort, n, descending,
|
||||
self_ptr, values_ptr, indices_ptr);
|
||||
}
|
||||
|
||||
remaining -= n;
|
||||
|
Reference in New Issue
Block a user