Speed up segmented sort with large nsort

Follow-up to gh-77100

Instead of calling `at::arange`, I repurpose the existing kernels to
achieve the same effect. I've also fixed a bug in the 2d case where
the pointer was advanced by `n` (which equals `nsegments * nsort`) after
processing only `nsort` elements.
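
As a rough illustration of that pointer fix (this is not the code from this PR), here is a minimal host-side C++ sketch; `segment_sort_stub` and `sort_segments` are hypothetical stand-ins for the per-segment cub sort and the batching loop. The key point is that each pointer advances by `nsort` per segment rather than by `n = nsegments * nsort`.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>

// Hypothetical stand-in for the per-segment sort: computes the sorting
// permutation of one segment and writes the sorted values.
void segment_sort_stub(const float* in, float* values, int64_t* indices, int64_t nsort) {
  std::iota(indices, indices + nsort, int64_t{0});
  std::stable_sort(indices, indices + nsort,
                   [&](int64_t a, int64_t b) { return in[a] < in[b]; });
  for (int64_t j = 0; j < nsort; ++j) {
    values[j] = in[indices[j]];
  }
}

void sort_segments(const float* self_ptr, float* values_ptr, int64_t* indices_ptr,
                   int64_t nsegments, int64_t nsort) {
  for (int64_t seg = 0; seg < nsegments; ++seg) {
    segment_sort_stub(self_ptr, values_ptr, indices_ptr, nsort);
    // Advance by nsort (one segment), not by n == nsegments * nsort.
    self_ptr += nsort;
    values_ptr += nsort;
    indices_ptr += nsort;
  }
}
```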

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188

Approved by: https://github.com/ngimel
Author: Peter Bell
Date: 2022-05-11 14:03:02 +01:00
Committed by: PyTorch MergeBot
Parent: 99339fddd9
Commit: d9351ed520
2 changed files with 49 additions and 1 deletion


@@ -11,6 +11,7 @@
#include <ATen/native/cuda/SortingCommon.cuh>
#include <limits>
#include <c10/core/DeviceArray.h>
namespace at { namespace native {
@@ -231,6 +232,7 @@ __global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64
}
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_index_and_segment_kernel(
    int2 *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
@@ -241,6 +243,7 @@ __global__ void fill_index_and_segment_kernel(
  }
}
C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
__global__ void fill_reverse_indices_kernel(
    int64_t *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
  CUDA_KERNEL_LOOP(idx, numel) {
@@ -248,6 +251,31 @@ __global__ void fill_reverse_indices_kernel(
  }
}
template<typename scalar_t>
inline void segmented_sort_large_segments(
    const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending,
    const scalar_t * self_ptr, scalar_t * values_ptr, int64_t * indices_ptr
) {
  using namespace at::cuda::detail;
  auto allocator = at::cuda::getCUDADeviceAllocator();
  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 block = CUDA_NUM_THREADS;
  dim3 grid = GET_BLOCKS(nsort);
  c10::DeviceArray<int64_t> indices(*allocator, nsort);
  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
      indices.get(), nsort, nsort_divider);
  const int64_t *initial_indices = indices.get();
  for (auto i : c10::irange(nsegments)) {
    at::cuda::cub::radix_sort_pairs<scalar_t, int64_t>(
        self_ptr, values_ptr, initial_indices, indices_ptr,
        nsort, descending);
    indices_ptr += nsort;
    self_ptr += nsort;
    values_ptr += nsort;
  }
}
template<typename scalar_t>
inline void segmented_sort_pairs_by_full_sort(
@@ -340,7 +368,11 @@ void launch_stable_sort_kernel(
    int64_t n = std::min(remaining, nbatch);
    int64_t nsegments = n / nsort;
    if (nsegments < 128) {
    if (nsegments == 1 || nsort >= 1000000) { // rough heuristics where even a single sort occupies GPU
      segmented_sort_large_segments(
          nsegments, nsort, n, descending,
          self_ptr, values_ptr, indices_ptr);
    } else if (nsegments < 128) {
      segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
          self_ptr, values_ptr, indices_ptr);
    } else {


@@ -130,6 +130,22 @@ class TestSortAndSelect(TestCase):
        self.assertIsOrdered('descending', x, res2val, res2ind,
                             'random with NaNs')
    @onlyCUDA
    def test_sort_large_slice(self, device):
        # tests direct cub path
        x = torch.randn(4, 1024000, device=device)
        res1val, res1ind = torch.sort(x, stable=True)
        torch.cuda.synchronize()
        # assertIsOrdered is too slow, so just compare to cpu
        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
        self.assertEqual(res1val, res1val_cpu.cuda())
        self.assertEqual(res1ind, res1ind_cpu.cuda())
        res1val, res1ind = torch.sort(x, descending=True, stable=True)
        torch.cuda.synchronize()
        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
        self.assertEqual(res1val, res1val_cpu.cuda())
        self.assertEqual(res1ind, res1ind_cpu.cuda())
    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_stable_sort(self, device, dtype):