diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 5ba6e89714a5..7ced1eb71cd9 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -253,6 +253,25 @@ __global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64
   }
 }
 
+
+__global__ void fill_index_and_segment_kernel(
+    int2 *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
+  CUDA_KERNEL_LOOP(idx, numel) {
+    auto div_mod = nsort_divider.divmod(idx);
+    auto segment = static_cast<int>(div_mod.div);
+    auto sort = static_cast<int>(div_mod.mod);
+    data[idx] = int2{segment, sort};
+  }
+}
+
+__global__ void fill_reverse_indices_kernel(
+    int64_t *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
+  CUDA_KERNEL_LOOP(idx, numel) {
+    data[idx] = nsort_divider.mod(idx);
+  }
+}
+
+
 template <typename scalar_t>
 inline void segmented_sort_pairs_by_full_sort(
   int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices,
@@ -260,16 +279,21 @@ inline void segmented_sort_pairs_by_full_sort(
 ) {
   int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
 
-  auto int_options = indices.options().dtype(kInt);
-  auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
-  indices_and_segment.select(-1, 0).copy_(  // segment id
-    at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
-  indices_and_segment.select(-1, 1).copy_(  // reverse indices
-    at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+  const auto numel = nsort * nsegments;
+  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
+  auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2));
+  auto i_s_ptr = static_cast<int2 *>(indices_and_segment.get());
 
-  auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
-  auto indices_and_segment2 = at::empty_like(indices_and_segment);
-  auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
+  using namespace at::cuda::detail;
+  dim3 block = CUDA_NUM_THREADS;
+  dim3 grid = GET_BLOCKS(numel);
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
+  fill_index_and_segment_kernel<<<grid, block, 0, stream>>>(
+      i_s_ptr, numel, nsort_divider);
+
+  auto indices_and_segment2 = cuda_allocator->allocate(nsegments * nsort * sizeof(int2));
+  auto i_s_ptr2 = static_cast<int2 *>(indices_and_segment2.get());
 
   at::cuda::cub::radix_sort_pairs(
     self_ptr, nullptr, i_s_ptr, i_s_ptr2,
@@ -286,6 +310,28 @@ inline void segmented_sort_pairs_by_full_sort(
     self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
 }
 
+template <typename scalar_t>
+void segmented_sort_pairs(
+    int64_t nsegments, int64_t nsort, int64_t n, bool descending,
+    const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr) {
+  const auto numel = nsort * nsegments;
+  auto cuda_allocator = at::cuda::getCUDADeviceAllocator();
+  auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t));
+  int64_t *reverse_indices_ptr = static_cast<int64_t *>(reverse_indices.get());
+
+  using namespace at::cuda::detail;
+  dim3 block = CUDA_NUM_THREADS;
+  dim3 grid = GET_BLOCKS(numel);
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
+  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
+      reverse_indices_ptr, numel, nsort_divider);
+
+  at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+      reverse_indices_ptr, indices_ptr, n, nsegments,
+      offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
+}
+
 }  // namespace
 
 // We perform a segmented sort in cub with inputs that have
@@ -414,6 +460,7 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
   int64_t numel_or_intmax = std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
   int64_t nbatch = (numel_or_intmax / nsort) * nsort;
+  TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort);
 
 #if defined(USE_ROCM)
   constexpr bool is_rocm = true;
 #else
@@ -434,10 +481,8 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
       segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending, indices,
         self_ptr, values_ptr, indices_ptr);
     } else {
-      auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
-      at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
-        reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
-        offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
+      segmented_sort_pairs(nsegments, nsort, n, descending,
+        self_ptr, values_ptr, indices_ptr);
     }
 
     remaining -= n;