mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Speedup segmented sort with large nsort
Follow up to gh-77100 Instead of calling `at::arange`, I repurpose the existing kernels to achieve the same effect. I've also fixed the 2d case bug where the pointer was advanced by `n` which equals `nsegment * nsort` after only processing `nsort` elements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/77188 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
99339fddd9
commit
d9351ed520
@ -130,6 +130,22 @@ class TestSortAndSelect(TestCase):
|
||||
self.assertIsOrdered('descending', x, res2val, res2ind,
|
||||
'random with NaNs')
|
||||
|
||||
@onlyCUDA
|
||||
def test_sort_large_slice(self, device):
|
||||
# tests direct cub path
|
||||
x = torch.randn(4, 1024000, device=device)
|
||||
res1val, res1ind = torch.sort(x, stable=True)
|
||||
torch.cuda.synchronize()
|
||||
# assertIsOrdered is too slow, so just compare to cpu
|
||||
res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
|
||||
self.assertEqual(res1val, res1val_cpu.cuda())
|
||||
self.assertEqual(res1ind, res1ind_cpu.cuda())
|
||||
res1val, res1ind = torch.sort(x, descending=True, stable=True)
|
||||
torch.cuda.synchronize()
|
||||
res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
|
||||
self.assertEqual(res1val, res1val_cpu.cuda())
|
||||
self.assertEqual(res1ind, res1ind_cpu.cuda())
|
||||
|
||||
# FIXME: remove torch.bool from unsupported types once support is added for cub sort
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
def test_stable_sort(self, device, dtype):
|
||||
|
||||
Reference in New Issue
Block a user