diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 3ceb3001e50b..5c08ddf59782 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -11,6 +11,7 @@
 #include 
 #include 
+#include 
 
 namespace at { namespace native {
 
@@ -231,6 +232,7 @@ __global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64
 }
 
+C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
 __global__ void fill_index_and_segment_kernel(
     int2 *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
   CUDA_KERNEL_LOOP(idx, numel) {
@@ -241,6 +243,7 @@
   }
 }
 
+C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
 __global__ void fill_reverse_indices_kernel(
     int64_t *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
   CUDA_KERNEL_LOOP(idx, numel) {
@@ -248,6 +251,31 @@
   }
 }
 
+template <typename scalar_t>
+inline void segmented_sort_large_segments(
+    const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending,
+    const scalar_t * self_ptr, scalar_t * values_ptr, int64_t * indices_ptr
+  ) {
+  using namespace at::cuda::detail;
+  auto allocator = at::cuda::getCUDADeviceAllocator();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  dim3 block = CUDA_NUM_THREADS;
+  dim3 grid = GET_BLOCKS(nsort);
+  c10::DeviceArray<int64_t> indices(*allocator, nsort);
+  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
+  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
+      indices.get(), nsort, nsort_divider);
+  const int64_t *initial_indices = indices.get();
+
+  for (auto i: c10::irange(nsegments)){
+    at::cuda::cub::radix_sort_pairs(
+        self_ptr, values_ptr, initial_indices, indices_ptr,
+        nsort, descending);
+    indices_ptr += nsort;
+    self_ptr += nsort;
+    values_ptr += nsort;
+  }
+}
 
 template <typename scalar_t>
 inline void segmented_sort_pairs_by_full_sort(
@@ -340,7 +368,11 @@ void launch_stable_sort_kernel(
     int64_t n = std::min(remaining, nbatch);
     int64_t nsegments = n / nsort;
 
-    if (nsegments < 128) {
+    if (nsegments == 1 || nsort >= 1000000) { // rough heuristic: even a single sort of this size occupies the GPU
+      segmented_sort_large_segments(
+          nsegments, nsort, n, descending,
+          self_ptr, values_ptr, indices_ptr);
+    } else if (nsegments < 128) {
       segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
         self_ptr, values_ptr, indices_ptr);
     } else {
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 7442a2e65bdd..19394c0809c8 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -130,6 +130,22 @@ class TestSortAndSelect(TestCase):
 
         self.assertIsOrdered('descending', x, res2val, res2ind, 'random with NaNs')
 
+    @onlyCUDA
+    def test_sort_large_slice(self, device):
+        # tests direct cub path
+        x = torch.randn(4, 1024000, device=device)
+        res1val, res1ind = torch.sort(x, stable=True)
+        torch.cuda.synchronize()
+        # assertIsOrdered is too slow, so just compare to cpu
+        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
+        self.assertEqual(res1val, res1val_cpu.cuda())
+        self.assertEqual(res1ind, res1ind_cpu.cuda())
+        res1val, res1ind = torch.sort(x, descending=True, stable=True)
+        torch.cuda.synchronize()
+        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
+        self.assertEqual(res1val, res1val_cpu.cuda())
+        self.assertEqual(res1ind, res1ind_cpu.cuda())
+
     # FIXME: remove torch.bool from unsupported types once support is added for cub sort
     @dtypes(*all_types_and(torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):
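
Note (not part of the patch): a minimal sketch of how the new dispatch branch is exercised from Python, assuming a CUDA build that includes this change. A sort slice of at least 1000000 elements, or a sort with a single segment, should take the segmented_sort_large_segments path added above; the exact shapes below are illustrative only.

    import torch

    # Large slice: nsort = 1024000 >= 1000000, mirroring test_sort_large_slice above.
    x = torch.randn(4, 1024000, device="cuda")
    values, indices = torch.sort(x, stable=True)

    # Single segment: sorting a 1-D tensor gives nsegments == 1, which also
    # selects the new path even though nsort is below the size threshold.
    y = torch.randn(100_000, device="cuda")
    values_desc, indices_desc = torch.sort(y, descending=True, stable=True)
    torch.cuda.synchronize()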