mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support torch.bool in torch.sort + CUDA (#139409)
Summary: This might be out-dated, so I'm adding it back and see if we pass all the tests. I'm pretty sure cuda12 is ok. Test Plan: CI Differential Revision: D65282650 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139409 Approved by: https://github.com/zou3519, https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
06f619d999
commit
e7cf7d00be
@ -193,8 +193,7 @@ class TestSortAndSelect(TestCase):
|
||||
self.assertEqual(res1val, res1val_cpu.cuda())
|
||||
self.assertEqual(res1ind, res1ind_cpu.cuda())
|
||||
|
||||
# FIXME: remove torch.bool from unsupported types once support is added for cub sort
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
@dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
|
||||
def test_stable_sort(self, device, dtype):
|
||||
sizes = (100, 1000, 10000)
|
||||
for ncopies in sizes:
|
||||
@ -323,8 +322,7 @@ class TestSortAndSelect(TestCase):
|
||||
self.assertEqual(indices, indices_cont)
|
||||
self.assertEqual(values, values_cont)
|
||||
|
||||
# FIXME: remove torch.bool from unsupported types once support is added for cub sort
|
||||
@dtypes(*all_types_and(torch.half, torch.bfloat16))
|
||||
@dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
|
||||
def test_stable_sort_against_numpy(self, device, dtype):
|
||||
if dtype in floating_types_and(torch.float16, torch.bfloat16):
|
||||
inf = float("inf")
|
||||
|
Reference in New Issue
Block a user