Speedup segmented sort with large nsort

Follow up to gh-77100

Instead of calling `at::arange`, I repurpose the existing kernels to
achieve the same effect. I've also fixed the 2d case bug where
the pointer was advanced by `n` which equals `nsegment * nsort` after
only processing `nsort` elements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188

Approved by: https://github.com/ngimel
This commit is contained in:
Peter Bell
2022-05-11 14:03:02 +01:00
committed by PyTorch MergeBot
parent 99339fddd9
commit d9351ed520
2 changed files with 49 additions and 1 deletions

View File

@ -130,6 +130,22 @@ class TestSortAndSelect(TestCase):
self.assertIsOrdered('descending', x, res2val, res2ind,
'random with NaNs')
@onlyCUDA
def test_sort_large_slice(self, device):
# tests direct cub path
x = torch.randn(4, 1024000, device=device)
res1val, res1ind = torch.sort(x, stable=True)
torch.cuda.synchronize()
# assertIsOrdered is too slow, so just compare to cpu
res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
self.assertEqual(res1val, res1val_cpu.cuda())
self.assertEqual(res1ind, res1ind_cpu.cuda())
res1val, res1ind = torch.sort(x, descending=True, stable=True)
torch.cuda.synchronize()
res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
self.assertEqual(res1val, res1val_cpu.cuda())
self.assertEqual(res1ind, res1ind_cpu.cuda())
# FIXME: remove torch.bool from unsupported types once support is added for cub sort
@dtypes(*all_types_and(torch.half, torch.bfloat16))
def test_stable_sort(self, device, dtype):