Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Speedup segmented sort with large nsort
Follow-up to gh-77100. Instead of calling `at::arange`, I repurpose the existing kernels to achieve the same effect. I've also fixed the 2d-case bug where the pointer was advanced by `n`, which equals `nsegment * nsort`, after only processing `nsort` elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188
Approved by: https://github.com/ngimel
Committed by PyTorch MergeBot
parent 99339fddd9
commit d9351ed520
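As illustration of the `at::arange` replacement described in the commit message, here is a minimal standalone CUDA sketch (not PyTorch code): a grid-stride kernel writes each element's offset within its sort slice into one device buffer that every segment then reuses. The kernel name `fill_slice_indices_kernel` is hypothetical and the `idx % nsort` formula is an assumption; the body of the real `fill_reverse_indices_kernel` is not shown in this diff.

// Standalone sketch, not part of the patch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill_slice_indices_kernel(int64_t* data, int numel, int64_t nsort) {
  // Grid-stride loop, analogous to CUDA_KERNEL_LOOP in the patched file.
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < numel;
       idx += gridDim.x * blockDim.x) {
    data[idx] = idx % nsort;  // element's position within its sort slice
  }
}

int main() {
  const int64_t nsort = 1 << 20;  // one large slice; launched with numel == nsort
  int64_t* d_indices = nullptr;   // this single buffer is reused for every segment
  cudaMalloc(&d_indices, nsort * sizeof(int64_t));

  const int threads = 256;
  const int blocks = static_cast<int>((nsort + threads - 1) / threads);
  fill_slice_indices_kernel<<<blocks, threads>>>(d_indices, static_cast<int>(nsort), nsort);
  cudaDeviceSynchronize();

  // Spot-check the tail: expect nsort-4 .. nsort-1.
  int64_t h[4];
  cudaMemcpy(h, d_indices + nsort - 4, sizeof(h), cudaMemcpyDeviceToHost);
  printf("%lld %lld %lld %lld\n", (long long)h[0], (long long)h[1],
         (long long)h[2], (long long)h[3]);
  cudaFree(d_indices);
  return 0;
}

When launched with `numel == nsort`, as `segmented_sort_large_segments` does below, this produces exactly the `0..nsort-1` sequence that `at::arange(nsort)` would have materialized as a Tensor.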
@@ -11,6 +11,7 @@
 #include <ATen/native/cuda/SortingCommon.cuh>
 
 #include <limits>
+#include <c10/core/DeviceArray.h>
 
 namespace at { namespace native {
 
@@ -231,6 +232,7 @@ __global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64
 }
 
 
+C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
 __global__ void fill_index_and_segment_kernel(
     int2 *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
   CUDA_KERNEL_LOOP(idx, numel) {
@@ -241,6 +243,7 @@ __global__ void fill_index_and_segment_kernel(
   }
 }
 
+C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS)
 __global__ void fill_reverse_indices_kernel(
     int64_t *data, int numel, at::cuda::detail::IntDivider<uint32_t> nsort_divider) {
   CUDA_KERNEL_LOOP(idx, numel) {
@@ -248,6 +251,31 @@ __global__ void fill_reverse_indices_kernel(
   }
 }
 
+template<typename scalar_t>
+inline void segmented_sort_large_segments(
+    const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending,
+    const scalar_t * self_ptr, scalar_t * values_ptr, int64_t * indices_ptr
+  ) {
+  using namespace at::cuda::detail;
+  auto allocator = at::cuda::getCUDADeviceAllocator();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  dim3 block = CUDA_NUM_THREADS;
+  dim3 grid = GET_BLOCKS(nsort);
+  c10::DeviceArray<int64_t> indices(*allocator, nsort);
+  at::cuda::detail::IntDivider<uint32_t> nsort_divider(nsort);
+  fill_reverse_indices_kernel<<<grid, block, 0, stream>>>(
+      indices.get(), nsort, nsort_divider);
+  const int64_t *initial_indices = indices.get();
+
+  for (auto i: c10::irange(nsegments)){
+    at::cuda::cub::radix_sort_pairs<scalar_t, int64_t>(
+        self_ptr, values_ptr, initial_indices, indices_ptr,
+        nsort, descending);
+    indices_ptr += nsort;
+    self_ptr += nsort;
+    values_ptr += nsort;
+  }
+}
 
 template<typename scalar_t>
 inline void segmented_sort_pairs_by_full_sort(
@@ -340,7 +368,11 @@ void launch_stable_sort_kernel(
       int64_t n = std::min(remaining, nbatch);
       int64_t nsegments = n / nsort;
 
-      if (nsegments < 128) {
+      if (nsegments == 1 || nsort >= 1000000) { //rough heuristics where even a single sort occupies GPU
+        segmented_sort_large_segments(
+            nsegments, nsort, n, descending,
+            self_ptr, values_ptr, indices_ptr);
+      } else if (nsegments < 128) {
         segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
           self_ptr, values_ptr, indices_ptr);
       } else {
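The per-segment loop added above, together with the pointer bug called out in the commit message, can be sketched standalone as follows, with `thrust::stable_sort_by_key` standing in for `at::cuda::cub::radix_sort_pairs` (an assumption for illustration, not the patch itself). Each iteration sorts one `nsort`-long slice and advances the value and index pointers by `nsort`; advancing them by `n = nsegments * nsort` instead, the 2d-case bug being fixed here, would step past the end of the buffers after the first segment.

// Standalone sketch, not part of the patch.
#include <cstdio>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>

int main() {
  const int64_t nsegments = 4, nsort = 8, n = nsegments * nsort;

  thrust::device_vector<float> values(n);
  thrust::device_vector<int64_t> indices(n);
  // Descending pattern per segment so the sort visibly permutes each slice.
  for (int64_t i = 0; i < n; ++i) values[i] = static_cast<float>(nsort - (i % nsort));

  float* values_ptr = thrust::raw_pointer_cast(values.data());
  int64_t* indices_ptr = thrust::raw_pointer_cast(indices.data());

  for (int64_t seg = 0; seg < nsegments; ++seg) {
    // (Re)create the 0..nsort-1 index sequence for this segment.
    thrust::sequence(thrust::device, indices_ptr, indices_ptr + nsort);
    // Sort one segment's values, carrying its original positions along.
    thrust::stable_sort_by_key(thrust::device, values_ptr, values_ptr + nsort,
                               indices_ptr);
    values_ptr += nsort;   // advance by nsort per segment, not by n
    indices_ptr += nsort;  // (advancing by n was the bug described above)
  }

  // First segment held 8,7,...,1, so its sorted index permutation is 7..0.
  thrust::host_vector<int64_t> h(indices.begin(), indices.begin() + nsort);
  for (int64_t v : h) printf("%lld ", static_cast<long long>(v));
  printf("\n");
  return 0;
}

The sketch recreates the index sequence inside the loop only to stay self-contained; the patch fills it once with `fill_reverse_indices_kernel` into a `c10::DeviceArray` and reuses the same `initial_indices` buffer for every segment.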
@@ -130,6 +130,22 @@ class TestSortAndSelect(TestCase):
             self.assertIsOrdered('descending', x, res2val, res2ind,
                                  'random with NaNs')
 
+    @onlyCUDA
+    def test_sort_large_slice(self, device):
+        # tests direct cub path
+        x = torch.randn(4, 1024000, device=device)
+        res1val, res1ind = torch.sort(x, stable=True)
+        torch.cuda.synchronize()
+        # assertIsOrdered is too slow, so just compare to cpu
+        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
+        self.assertEqual(res1val, res1val_cpu.cuda())
+        self.assertEqual(res1ind, res1ind_cpu.cuda())
+        res1val, res1ind = torch.sort(x, descending=True, stable=True)
+        torch.cuda.synchronize()
+        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
+        self.assertEqual(res1val, res1val_cpu.cuda())
+        self.assertEqual(res1ind, res1ind_cpu.cuda())
+
     # FIXME: remove torch.bool from unsupported types once support is added for cub sort
     @dtypes(*all_types_and(torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):