Support torch.bool in torch.sort + CUDA (#139409)

Summary: The FIXME excluding torch.bool from cub sort may be outdated, so I'm adding bool back to see whether we pass all the tests. I'm fairly confident CUDA 12 is fine.
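For illustration, a minimal sanity check of the path this change enables (assumes a CUDA device is available; this snippet is not part of the PR):

```python
import torch

# Stable sort of a bool tensor on CUDA, which the tests previously
# excluded because bool was not supported by the cub sort path.
x = torch.tensor([True, False, True, False], device="cuda")
values, indices = torch.sort(x, stable=True)
print(values)   # tensor([False, False,  True,  True], device='cuda:0')
print(indices)  # tensor([1, 3, 0, 2], ...) -- stable: ties keep original order
```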

Test Plan: CI

Differential Revision: D65282650

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139409
Approved by: https://github.com/zou3519, https://github.com/ngimel, https://github.com/eqy
Author: Xiaodong Wang
Date: 2024-11-06 00:02:52 +00:00
Committed by: PyTorch MergeBot
Commit: e7cf7d00be (parent 06f619d999)

3 changed files with 20 additions and 11 deletions


@@ -193,8 +193,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(res1val, res1val_cpu.cuda())
             self.assertEqual(res1ind, res1ind_cpu.cuda())
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):
         sizes = (100, 1000, 10000)
         for ncopies in sizes:
@@ -323,8 +322,7 @@ class TestSortAndSelect(TestCase):
         self.assertEqual(indices, indices_cont)
         self.assertEqual(values, values_cont)
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort_against_numpy(self, device, dtype):
         if dtype in floating_types_and(torch.float16, torch.bfloat16):
             inf = float("inf")