Support torch.bool in torch.sort + CUDA (#139409)

Summary: The bool restriction might be outdated, so I'm adding bool support back to see whether we pass all the tests. I'm pretty sure CUDA 12 is fine.

Test Plan: CI

Differential Revision: D65282650

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139409
Approved by: https://github.com/zou3519, https://github.com/ngimel, https://github.com/eqy
Author: Xiaodong Wang
Date: 2024-11-06 00:02:52 +00:00
Committed by: PyTorch MergeBot
Parent: 06f619d999
Commit: e7cf7d00be
3 changed files with 20 additions and 11 deletions

@@ -63,9 +63,6 @@ void sort_cuda_kernel(
     "The dimension being sorted can not have more than INT_MAX elements.");
   const auto self_dtype = self.dtype();
-  // FIXME: remove this check once cub sort supports bool
-  TORCH_CHECK(self_dtype != ScalarType::Bool,
-    "Sort currently does not support bool dtype on CUDA.");
   TORCH_CHECK(self_dtype != ScalarType::ComplexFloat && self_dtype != ScalarType::ComplexDouble,
     "Sort currently does not support complex dtypes on CUDA.");

@@ -193,8 +193,7 @@ class TestSortAndSelect(TestCase):
         self.assertEqual(res1val, res1val_cpu.cuda())
         self.assertEqual(res1ind, res1ind_cpu.cuda())
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort(self, device, dtype):
         sizes = (100, 1000, 10000)
         for ncopies in sizes:
@@ -323,8 +322,7 @@ class TestSortAndSelect(TestCase):
             self.assertEqual(indices, indices_cont)
             self.assertEqual(values, values_cont)
 
-    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
-    @dtypes(*all_types_and(torch.half, torch.bfloat16))
+    @dtypes(*all_types_and(torch.bool, torch.half, torch.bfloat16))
     def test_stable_sort_against_numpy(self, device, dtype):
         if dtype in floating_types_and(torch.float16, torch.bfloat16):
            inf = float("inf")
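The decorator changes above add torch.bool to the dtypes exercised by the stable-sort tests. A hedged sketch of the property they rely on, comparing a stable CUDA sort against a stable CPU sort (illustrative sizes, not the tests' own code):

```python
import torch

# With stable=True, ties keep their original relative order, so both the
# sorted values and the returned indices are fully determined and should
# match across devices.
if torch.cuda.is_available():
    x_cpu = torch.randint(0, 2, (1000,), dtype=torch.bool)
    x_cuda = x_cpu.cuda()

    vals_cpu, idx_cpu = torch.sort(x_cpu, stable=True)
    vals_cuda, idx_cuda = torch.sort(x_cuda, stable=True)

    torch.testing.assert_close(vals_cuda.cpu(), vals_cpu)
    torch.testing.assert_close(idx_cuda.cpu(), idx_cpu)
```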

@@ -3320,6 +3320,9 @@ def sample_inputs_sort(op_info, device, dtype, requires_grad, **kwargs):
     flag = [True, False]
     for dim, descending, stable in product(dims, flag, flag):
         # default schema without stable sort
-        yield SampleInput(small_3d_unique(), dim, descending)
+        if not (dtype == torch.bool and torch.device(device).type == 'cuda'):
+            # bool and cuda requires stable sort for stable results, at least
+            # for the return index
+            yield SampleInput(small_3d_unique(), dim, descending)
         # schema with stable sort, no CUDA support yet
         if torch.device(device).type == 'cpu':
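The new guard only yields the default (non-stable) schema when the sample is not bool on CUDA: with just two distinct values almost every element is a tie, so the returned indices are only reproducible when stable=True is requested. A minimal sketch of that distinction, assuming a CUDA device (not from the commit):

```python
import torch

if torch.cuda.is_available():
    x = torch.randint(0, 2, (64,), dtype=torch.bool, device="cuda")

    # Stable sort: within the False block and the True block, the original
    # order is preserved, so the index tensor is fully deterministic.
    _, idx_stable = torch.sort(x, stable=True)

    # Default (non-stable) sort: values are still correctly ordered, but the
    # ordering of indices within each block of ties is unspecified.
    _, idx_default = torch.sort(x)
```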
@@ -18477,11 +18480,13 @@ op_db: List[OpInfo] = [
            )),
     OpInfo('sort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_sort,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('unique',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64),
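The added DecorateInfo marks test_non_standard_bool_values as an expected failure for bool on CUDA. Non-standard bool values are, roughly, bool tensors whose backing bytes are something other than 0 and 1; the CUDA path can order them differently from the CPU path, presumably because the radix sort keys on the raw bytes. A hedged, illustrative construction (not the test's actual code):

```python
import torch

if torch.cuda.is_available():
    raw = torch.tensor([0, 2, 0, 255, 1], dtype=torch.uint8, device="cuda")
    x = raw.view(torch.bool)  # reinterpret the bytes as bool without copying;
                              # 2 and 255 are truthy but not the canonical 1
    values, indices = torch.sort(x, stable=True)
```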
@@ -19506,12 +19511,14 @@ op_db: List[OpInfo] = [
            sample_inputs_func=sample_inputs_unfold),
     OpInfo('msort',
            dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            check_batched_gradgrad=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_msort,
            skips=(
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_non_standard_bool_values',
+                            dtypes=[torch.bool], device_type='cuda'),
            )),
     OpInfo('movedim',
            aliases=('moveaxis',),
@@ -21324,7 +21331,7 @@ op_db: List[OpInfo] = [
     OpInfo(
         "argsort",
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
-        dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
+        dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
         sample_inputs_func=sample_inputs_sort,
         supports_out=False,
         supports_autograd=False,
@@ -21335,6 +21342,13 @@ op_db: List[OpInfo] = [
                 "test_variant_consistency_jit",
                 dtypes=(torch.float32,),
             ),
+            DecorateInfo(
+                unittest.expectedFailure,
+                "TestCommon",
+                "test_non_standard_bool_values",
+                dtypes=[torch.bool],
+                device_type='cuda',
+            ),
         ),
     ),
     OpInfo(